gaoqiong / lm-evaluation-harness · Commits

Commit 5e0bc289 (unverified): fix tests (#2380)
Authored Oct 04, 2024 by Baber Abbasi; committed via GitHub on Oct 04, 2024.
Parent: cb069004

Showing 8 changed files with 1727 additions and 1281 deletions (+1727, -1281).
Changed files:
- .pre-commit-config.yaml (+2, -2)
- examples/lm-eval-overview.ipynb (+1204, -1202)
- examples/visualize-wandb.ipynb (+2, -0)
- lm_eval/api/task.py (+1, -1)
- lm_eval/tasks/galician_bench/flores_gl/create_yamls_flores_gl.py (+255, -37)
- lm_eval/tasks/portuguese_bench/flores_pt/create_yamls_flores_pt.py (+254, -36)
- lm_eval/tasks/spanish_bench/flores_es/create_yamls_flores_es.py (+3, -2)
- tests/models/test_huggingface.py (+6, -1)
.pre-commit-config.yaml

Two hook versions are bumped:

@@ -2,7 +2,7 @@
 exclude: ^tests/testdata/
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.5.0
+    rev: v4.6.0
     hooks:
       - id: check-added-large-files
       - id: check-ast
...

@@ -29,7 +29,7 @@ repos:
       - id: mixed-line-ending
         args: [--fix=lf]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.4.8
+    rev: v0.6.8
     hooks:
       # Run the linter.
       - id: ruff
...
examples/lm-eval-overview.ipynb

(Diff collapsed in the web view; +1204, -1202.)
examples/visualize-wandb.ipynb

Two blank-line strings are inserted into notebook cells (formatter spacing between imports and code):

@@ -68,6 +68,7 @@
    "source": [
     "import wandb\n",
     "\n",
+    "\n",
     "wandb.login()"
    ]
   },
...

@@ -130,6 +131,7 @@
     "import lm_eval\n",
     "from lm_eval.loggers import WandbLogger\n",
     "\n",
+    "\n",
     "results = lm_eval.simple_evaluate(\n",
     "    model=\"hf\",\n",
     "    model_args=\"pretrained=microsoft/phi-2,trust_remote_code=True\",\n",
...
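For context, this cell is part of the notebook's evaluate-then-log flow. A compact sketch of that flow, assuming the WandbLogger interface documented for the harness's W&B integration (post_init / log_eval_result / log_eval_samples); the project name, task list, and limit below are illustrative, not taken from the notebook:

    import wandb
    import lm_eval
    from lm_eval.loggers import WandbLogger

    wandb.login()  # authenticate once per session

    # A small, cheap evaluation run; task choice and limit are illustrative.
    results = lm_eval.simple_evaluate(
        model="hf",
        model_args="pretrained=microsoft/phi-2,trust_remote_code=True",
        tasks=["hellaswag"],
        limit=10,
        log_samples=True,  # keep per-sample records for the logger below
    )

    # Publish metrics and per-sample tables to Weights & Biases.
    wandb_logger = WandbLogger(project="lm-eval-harness", job_type="eval")
    wandb_logger.post_init(results)
    wandb_logger.log_eval_result()
    wandb_logger.log_eval_samples(results["samples"])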
lm_eval/api/task.py

@@ -1511,7 +1511,7 @@ class ConfigurableTask(Task):
             # we expect multiple_targets to be a list.
             elif self.multiple_target:
                 gold = list(gold)
-            elif type(gold) != type(result):
+            elif type(gold) is not type(result):
                 # cast gold to the same type as result
                 gold = type(result)(gold)
...
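The new spelling is the identity comparison that ruff (rule E721) prefers over != for type objects; the branch still fires only when gold and result have different types. A standalone sketch of the cast it guards, using a hypothetical helper rather than the harness API:

    # Hypothetical helper mirroring the branch above.
    def normalize_gold(gold, result):
        if type(gold) is not type(result):
            # Cast gold to result's type, e.g. tuple -> list or int -> float.
            gold = type(result)(gold)
        return gold

    assert normalize_gold(("a", "b"), ["x"]) == ["a", "b"]  # tuple cast to list
    assert normalize_gold(1, 2.0) == 1.0                    # int cast to float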
lm_eval/tasks/galician_bench/flores_gl/create_yamls_flores_gl.py

Most of the churn in this file is mechanical: the wildcard imports become explicit, and the newer ruff re-wraps the long lines, which drives the +255/-37. The substantive import change:

 import argparse
+import itertools
+
 import yaml
-from langcodes import *
-from itertools import *
+from langcodes import Language

The remaining hunks replace code with the same code in the new layout. The post-commit text follows, with long lists wrapped compactly here for readability:

# ruff: noqa: E731, E741
"""
Script to generate task YAMLs for the FLORES-200 dataset.
Based on `tasks/translation/utils.py`.
"""

import argparse
import itertools

import yaml
from langcodes import Language


# utils
flatten = lambda l: list(itertools.chain(*l))

# constants
_LANGUAGES = [
    "ace_Arab", "bam_Latn", "dzo_Tibt", "hin_Deva", "khm_Khmr", "mag_Deva", "pap_Latn", "sot_Latn", "tur_Latn",
    "ace_Latn", "ban_Latn", "ell_Grek", "hne_Deva", "kik_Latn", "mai_Deva", "pbt_Arab", "spa_Latn", "twi_Latn",
    "acm_Arab", "bel_Cyrl", "eng_Latn", "hrv_Latn", "kin_Latn", "mal_Mlym", "pes_Arab", "srd_Latn", "tzm_Tfng",
    "acq_Arab", "bem_Latn", "epo_Latn", "hun_Latn", "kir_Cyrl", "mar_Deva", "plt_Latn", "srp_Cyrl", "uig_Arab",
    "aeb_Arab", "ben_Beng", "est_Latn", "hye_Armn", "kmb_Latn", "min_Arab", "pol_Latn", "ssw_Latn", "ukr_Cyrl",
    "afr_Latn", "bho_Deva", "eus_Latn", "ibo_Latn", "kmr_Latn", "min_Latn", "por_Latn", "sun_Latn", "umb_Latn",
    "ajp_Arab", "bjn_Arab", "ewe_Latn", "ilo_Latn", "knc_Arab", "mkd_Cyrl", "prs_Arab", "swe_Latn", "urd_Arab",
    "aka_Latn", "bjn_Latn", "fao_Latn", "ind_Latn", "knc_Latn", "mlt_Latn", "quy_Latn", "swh_Latn", "uzn_Latn",
    "als_Latn", "bod_Tibt", "fij_Latn", "isl_Latn", "kon_Latn", "mni_Beng", "ron_Latn", "szl_Latn", "vec_Latn",
    "amh_Ethi", "bos_Latn", "fin_Latn", "ita_Latn", "kor_Hang", "mos_Latn", "run_Latn", "tam_Taml", "vie_Latn",
    "apc_Arab", "bug_Latn", "fon_Latn", "jav_Latn", "lao_Laoo", "mri_Latn", "rus_Cyrl", "taq_Latn", "war_Latn",
    "arb_Arab", "bul_Cyrl", "fra_Latn", "jpn_Jpan", "lij_Latn", "mya_Mymr", "sag_Latn", "taq_Tfng", "wol_Latn",
    "arb_Latn", "cat_Latn", "fur_Latn", "kab_Latn", "lim_Latn", "nld_Latn", "san_Deva", "tat_Cyrl", "xho_Latn",
    "ars_Arab", "ceb_Latn", "fuv_Latn", "kac_Latn", "lin_Latn", "nno_Latn", "sat_Olck", "tel_Telu", "ydd_Hebr",
    "ary_Arab", "ces_Latn", "gaz_Latn", "kam_Latn", "lit_Latn", "nob_Latn", "scn_Latn", "tgk_Cyrl", "yor_Latn",
    "arz_Arab", "cjk_Latn", "gla_Latn", "kan_Knda", "lmo_Latn", "npi_Deva", "shn_Mymr", "tgl_Latn", "yue_Hant",
    "asm_Beng", "ckb_Arab", "gle_Latn", "kas_Arab", "ltg_Latn", "nso_Latn", "sin_Sinh", "tha_Thai", "zho_Hans",
    "ast_Latn", "crh_Latn", "glg_Latn", "kas_Deva", "ltz_Latn", "nus_Latn", "slk_Latn", "tir_Ethi", "zho_Hant",
    "awa_Deva", "cym_Latn", "grn_Latn", "kat_Geor", "lua_Latn", "nya_Latn", "slv_Latn", "tpi_Latn", "zsm_Latn",
    "ayr_Latn", "dan_Latn", "guj_Gujr", "kaz_Cyrl", "lug_Latn", "oci_Latn", "smo_Latn", "tsn_Latn", "zul_Latn",
    "azb_Arab", "deu_Latn", "hat_Latn", "kbp_Latn", "luo_Latn", "ory_Orya", "sna_Latn", "tso_Latn",
    "azj_Latn", "dik_Latn", "hau_Latn", "kea_Latn", "lus_Latn", "pag_Latn", "snd_Arab", "tuk_Latn",
    "bak_Cyrl", "dyu_Latn", "heb_Hebr", "khk_Cyrl", "lvs_Latn", "pan_Guru", "som_Latn", "tum_Latn",
]
LANGUAGE_PAIRS = [
    (a, b) for idx, a in enumerate(_LANGUAGES) for b in _LANGUAGES[idx + 1 :]
]

LANGUAGES_OF_INTEREST = [
    "cat_Latn", "spa_Latn", "eng_Latn", "glg_Latn", "eus_Latn",
    "ita_Latn", "deu_Latn", "por_Latn", "fra_Latn",
]
MAIN_LANG = "glg_Latn"
LANGUAGE_PAIRS = [
    (a, b)
    for (a, b) in LANGUAGE_PAIRS
    if a in LANGUAGES_OF_INTEREST
    and b in LANGUAGES_OF_INTEREST
    and MAIN_LANG in (a, b)
]

# auxiliary functions
code_to_language_name = lambda code: Language.make(
    language=Language.get(code)["language"]
).display_name()
code_to_short_name = lambda code: Language.get(code)["language"]
jinja_var = (
    lambda s: "{{" + s + "}}"
)  # wrapper to avoid having to escape { } in format strings


def doc_to_text(src: str, tgt: str) -> str:
    src_name, tgt_name = map(code_to_language_name, [src, tgt])
...

@@ -56,12 +261,14 @@ def doc_to_text(src: str, tgt: str) -> str:
{src_name} sentence: {jinja_var('sentence_' + src)}
{tgt_name} sentence:"""


def doc_to_target(tgt: str) -> str:
    return f"{jinja_var('sentence_' + tgt)}"


# main function
def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
    """
    Generate a YAML file for each translation direction.
...

@@ -69,20 +276,23 @@ def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
    err = []
    for src, tgt in LANGUAGE_PAIRS:
        # do both translation directions for each lang pair
        for src, tgt in [(src, tgt), (tgt, src)]:
            lang_pair_name = f"{code_to_short_name(src)}-{code_to_short_name(tgt)}"
            yaml_file_name = f"flores_{lang_pair_name}.yaml"
            try:
                with open(
                    f"{output_dir}/{yaml_file_name}",
                    "w" if overwrite else "x",
                    encoding="utf-8",
                ) as outfile:
                    print(f"Creating {yaml_file_name}...")
                    outfile.write("# File generated by `create-yamls.py`\n")
                    yaml.dump(
                        {
                            # "group": [f"{BENCH_NAME}_bench", f"{BENCH_NAME}_bench_flores"],
                            # "group": "flores_gl",
                            "include": "_flores_common_yaml",
                            "task": f"flores_{lang_pair_name}",
                            "doc_to_text": doc_to_text(src, tgt),
...

@@ -105,11 +315,19 @@ def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--overwrite",
        default=False,
        action="store_true",
        help="Overwrite files if they already exist",
    )
    parser.add_argument(
        "--output-dir", default=".", help="Directory to write yaml files to"
    )
    args = parser.parse_args()
    gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite)


if __name__ == "__main__":
    main()
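The two LANGUAGE_PAIRS comprehensions first enumerate every unordered pair of the 204 FLORES codes, then keep only pairs between languages of interest that involve MAIN_LANG; gen_lang_yamls later emits one YAML per direction. A self-contained sketch of that selection with a toy four-language list (all of which count as "of interest", so that filter is omitted):

    # Toy reproduction of the pair-selection logic above.
    _LANGS = ["cat_Latn", "glg_Latn", "eng_Latn", "fra_Latn"]
    pairs = [(a, b) for i, a in enumerate(_LANGS) for b in _LANGS[i + 1 :]]
    main = "glg_Latn"
    pairs = [(a, b) for (a, b) in pairs if main in (a, b)]
    # Only the unordered pairs involving Galician survive:
    assert pairs == [
        ("cat_Latn", "glg_Latn"),
        ("glg_Latn", "eng_Latn"),
        ("glg_Latn", "fra_Latn"),
    ]
    # gen_lang_yamls then writes both directions per pair, e.g. gl-cat and cat-gl.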
lm_eval/tasks/portuguese_bench/flores_pt/create_yamls_flores_pt.py

Same change as the Galician script above, applied to its Portuguese copy. The import fix:

 import argparse
+import itertools
+
 import yaml
-from langcodes import *
-from itertools import *
+from langcodes import Language

and the rest of the file is re-wrapped identically. The only textual differences from create_yamls_flores_gl.py are MAIN_LANG = "por_Latn" instead of "glg_Latn"; a single commented-out # "group": "flores_pt" line in the yaml.dump payload (the Galician file carries two such comments); and hunk offsets one line lower as a result (@@ -69,19 +276,22 @@ and @@ -104,11 +314,19 @@).
lm_eval/tasks/spanish_bench/flores_es/create_yamls_flores_es.py

Only the imports change here (+3, -2); the body of the Spanish script is untouched:

 # ruff: noqa: E731, E741
 """
 Script to generate task YAMLs for the FLORES-200 dataset.
 Based on `tasks/translation/utils.py`.
 """

 import argparse
-from itertools import *
+import itertools
+
 import yaml
-from langcodes import *
+from langcodes import Language

 # utils
...
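With the explicit imports, the script's flatten helper reaches itertools.chain by qualified name instead of relying on a wildcard. A quick check of what it does (the E731 noqa matches the script's own suppression of the lambda-assignment lint):

    import itertools

    flatten = lambda l: list(itertools.chain(*l))  # noqa: E731
    assert flatten([[1, 2], [3], []]) == [1, 2, 3]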
tests/models/test_huggingface.py

@@ -5,7 +5,9 @@ import sys
 from pathlib import Path

 import numpy as np
+import tokenizers
 import torch
+from packaging.version import parse as parse_version

 from lm_eval import tasks
 from lm_eval.api.instance import Instance
...

@@ -145,4 +147,7 @@ class Test_HFLM:
         context = self.LM.tok_batch_encode([TEST_STRING])[0]
         res = self.LM._model_generate(context, max_length=10, stop=["\n\n"])
         res = self.LM.tok_decode(res[0])
-        assert res == "foo bar\n<bazhang>!info bar"
+        if parse_version(tokenizers.__version__) >= parse_version("0.20.0"):
+            assert res == "foo bar\n<bazhang> !info bar"
+        else:
+            assert res == "foo bar\n<bazhang>!info bar"
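This version gate is the substance behind the commit title: tokenizers 0.20.0 evidently decodes an extra space after the added <bazhang> token, so the expected string depends on the installed version. The same pattern in isolation, as a hypothetical helper rather than part of the test suite:

    import tokenizers
    from packaging.version import parse as parse_version

    # Hypothetical helper: pick the expected decode for the installed tokenizers.
    def expected_generation() -> str:
        if parse_version(tokenizers.__version__) >= parse_version("0.20.0"):
            return "foo bar\n<bazhang> !info bar"  # extra space from >= 0.20.0
        return "foo bar\n<bazhang>!info bar"

parse_version compares release numbers numerically, which matters here: as plain strings, "0.20.0" would sort before "0.4.8".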