gaoqiong / lm-evaluation-harness · Commits

Commit 50e99bd7
Authored Sep 20, 2023 by Herbie Bradley

    Merge remote-tracking branch 'origin/big-refactor' into calibration

Parents: 3d4c4cd6, a3252ed7
Changes: 49 files in this merge; this page shows 20 changed files with 397 additions and 251 deletions (+397 -251).
Changed files shown on this page:

    .github/workflows/new_tasks.yml                          +68  -68
    .github/workflows/unit_tests.yml                         +32  -33
    lm_eval/api/model.py                                     +23   -3
    lm_eval/api/task.py                                      +21  -14
    lm_eval/benchmarks/__init__.py                            +0  -63
    lm_eval/decontamination/janitor.py                        +9   -9
    lm_eval/evaluator.py                                    +104  -37
    lm_eval/models/huggingface.py                             +6   -3
    lm_eval/prompts/__init__.py                              +11   -2
    lm_eval/tasks/README.md                                   +2   -2
    lm_eval/tasks/__init__.py                                +61  -13
    lm_eval/tasks/benchmarks/pythia.yaml                      +4   -4
    lm_eval/tasks/benchmarks/t0_eval.yaml                     +0   -0
    lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_bn.yaml     +8   -0
    lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_de.yaml     +8   -0
    lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_en.yaml     +8   -0
    lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_es.yaml     +8   -0
    lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_fr.yaml     +8   -0
    lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_ja.yaml     +8   -0
    lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_ru.yaml     +8   -0
.github/workflows/new_tasks.yml  (+68 -68)

Every line of this workflow changes: a copy in which each line was commented out is replaced by the active configuration below. The two copies are otherwise identical except that the commented-out copy's pytest commands lack the -n=auto flag.

    name: Tasks Modified

    on:
      push:
        branches:
          - 'big-refactor*'
      pull_request:
        branches:
          - 'big-refactor*'
      workflow_dispatch:
    # comment/edit out the above to stop/change the triggers

    jobs:
      changed_files:
        runs-on: ubuntu-latest  # windows-latest || macos-latest
        timeout-minutes: 120
        name: Scan for changed tasks
        steps:
          - name: checkout
            uses: actions/checkout@v3
            with:
              fetch-depth: 2  # OR "2" -> To retrieve the preceding commit.

          # Uses the tj-actions/changed-files@v37 action to check for changes.
          # Outputs provided here: https://github.com/tj-actions/changed-files#outputs
          # The `files_yaml` input optionally takes a yaml string to specify filters,
          # and prepends the filter name to the standard output names.
          - name: Check task folders
            id: changed-tasks
            uses: tj-actions/changed-files@v37.1.2
            with:
              # tasks checks the tasks folder and api checks the api folder for changes
              files_yaml: |
                tasks:
                  - lm_eval/tasks/**
                api:
                  - lm_eval/api/**
              write_output_files: true

          # The next step is optional; the files are written to the workspace by default (above).
          # so it's just for debugging
          - name: Run Tests
            if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
            run: |
              echo .github/outputs/tasks_all_changed_and_modified_files.txt >> 'GITHUB_ENV'
              echo "One or more test file(s) has changed."
              echo "List of all the files that have changed: ${{ steps.changed-tasks.outputs.tasks_all_modified_files }}"

          - name: Set up Python 3.9
            if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
            uses: actions/setup-python@v4
            with:
              python-version: 3.9
              cache: 'pip'
              cache-dependency-path: setup.py

          - name: Install dependencies
            if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
            run: |
              python -m pip install --upgrade pip
              pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
              # Install optional git dependencies
              # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
              # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

          # if new tasks are added, run tests on them
          - name: Test with pytest
            if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
            run: python -m pytest tests/test_tasks.py -s -vv -n=auto

          # if api is modified, run tests on it
          - name: Test more tasks with pytest
            env:
              API: true
            if: steps.changed-tasks.outputs.api_any_modified == 'true'
            run: python -m pytest tests/test_tasks.py -s -vv -n=auto
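The `files_yaml` filters above are what gate the later steps: `tj-actions/changed-files` partitions the changed paths into the named groups and exposes per-group outputs such as `tasks_any_modified` and `api_any_modified`. As a rough, hedged sketch of the same idea in plain Python (the filter names and globs are copied from the workflow; the changed-file list and the `any_modified` helper are invented for illustration and are not part of the action):

```python
from fnmatch import fnmatch

# Filters copied from the workflow's `files_yaml` block; the changed-file list is invented.
FILTERS = {
    "tasks": ["lm_eval/tasks/**"],
    "api": ["lm_eval/api/**"],
}
changed_files = ["lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_bn.yaml", "README.md"]


def any_modified(filter_name: str) -> bool:
    # Rough equivalent of the action's `<filter>_any_modified` output.
    # Note: fnmatch's `*` already matches across `/`, so `**` needs no special handling here.
    return any(
        fnmatch(path, pattern)
        for path in changed_files
        for pattern in FILTERS[filter_name]
    )


if any_modified("tasks") or any_modified("api"):
    print("would run: python -m pytest tests/test_tasks.py -s -vv -n=auto")
```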
.github/workflows/unit_tests.yml  (+32 -33)

@@ -40,39 +40,38 @@ jobs:

The flake8 lint lines and the commented-out mypy step at the top of the hunk are unchanged context. The testcpu (CPU Tests) job, previously present only as a commented-out block whose Python matrix also listed "3.8", is re-enabled:

              flake8 . --count --select=F,E9,E71,E72,E501,E112,E113,W6 --extend-ignore=F541 --show-source --statistics --exit-zero
              # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
              flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
          # mypy turned off for now
          # - name: Lint with mypy
          #   run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
      # Job 2
      testcpu:
        name: CPU Tests
        runs-on: ubuntu-latest
        strategy:
          matrix:
            python-version: ["3.9", "3.10", "3.11"]
        timeout-minutes: 30
        steps:
          - name: Checkout Code
            uses: actions/checkout@v3
          - name: Set up Python ${{ matrix.python-version }}
            uses: actions/setup-python@v4
            with:
              python-version: ${{ matrix.python-version }}
              cache: pip
              cache-dependency-path: setup.py
          - name: Install dependencies
            run: |
              python -m pip install --upgrade pip
              pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
              # Install optional git dependencies
              # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
              # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
          - name: Test with pytest
            run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/tests_master --ignore=tests/extra
          - name: Archive artifacts
            uses: actions/upload-artifact@v3
            with:
              name: output_results
              path: |
                test_logs/*
lm_eval/api/model.py  (+23 -3)

The module now imports torch and extra typing helpers:

    import abc
    import os
    import torch
    from typing import Union, List, Tuple, Optional, Type, TypeVar
    from sqlitedict import SqliteDict
    import json
    import hashlib

@@ -11,6 +12,8 @@ from tqdm import tqdm — a TypeVar bound to LM is defined just above the class:

    from lm_eval import utils
    from lm_eval.logger import eval_logger

    T = TypeVar("T", bound="LM")


    class LM(abc.ABC):
        def __init__(self) -> None:

@@ -111,11 +114,28 @@ class LM(abc.ABC): create_from_arg_string gains type annotations, a docstring, and a broader MPS check (it now also matches "mps:0" and skips the float32 fallback on nightly "dev" builds of torch):

        @classmethod
        def create_from_arg_string(
            cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
        ) -> T:
            """
            Creates an instance of the LM class using the given argument string and additional config.

            Parameters:
            - arg_string: A string containing arguments in the format key1=value1,key2=value2.
            - additional_config: Optional dictionary containing additional configuration parameters.

            Returns:
            - Instance of the LM class.
            """
            additional_config = {} if additional_config is None else additional_config
            args = utils.simple_parse_args_string(arg_string)
            args2 = {k: v for k, v in additional_config.items() if v is not None}
            # TODO: delete once float16 MPS is fixed in torch stable
            if (
                args2.get("device") in ("mps", "mps:0")
                or args.get("device") in ("mps", "mps:0")
                and "dev" not in torch.__version__
            ):
                args["dtype"] = "float32"
            return cls(**args, **args2)

Previously the method was untyped (def create_from_arg_string(cls, arg_string, additional_config=None)) and the MPS check was simply args2.get("device") == "mps" or args.get("device") == "mps".
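To make the new `create_from_arg_string` behaviour concrete, here is a minimal, self-contained sketch. It is not the library's code: `simple_parse_args_string` lives in `lm_eval.utils` and is only imitated here, and the sketch groups the MPS/torch-version check explicitly rather than relying on operator precedence as the diff does.

```python
from typing import Optional


def simple_parse_args_string_sketch(arg_string: str) -> dict:
    # Stand-in for lm_eval.utils.simple_parse_args_string: "a=1,b=2" -> {"a": "1", "b": "2"}.
    return dict(kv.split("=", 1) for kv in arg_string.split(",") if kv)


def create_kwargs(arg_string: str, additional_config: Optional[dict] = None, torch_version: str = "2.0.1") -> dict:
    additional_config = {} if additional_config is None else additional_config
    args = simple_parse_args_string_sketch(arg_string)
    args2 = {k: v for k, v in additional_config.items() if v is not None}
    # Mirrors the new MPS guard: only stable (non-"dev") torch builds get the float32 fallback.
    if (args2.get("device") in ("mps", "mps:0") or args.get("device") in ("mps", "mps:0")) and "dev" not in torch_version:
        args["dtype"] = "float32"
    return {**args, **args2}


print(create_kwargs("pretrained=EleutherAI/pythia-160m,device=mps", {"batch_size": 8}))
# -> {'pretrained': 'EleutherAI/pythia-160m', 'device': 'mps', 'dtype': 'float32', 'batch_size': 8}
```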
lm_eval/api/task.py  (+21 -14)

@@ -674,11 +674,11 @@ class ConfigurableTask(Task): the whitespace sanity check on answer choices now only runs when doc_to_choice is set, and it inspects the leading character of each choice and the trailing character of the target delimiter rather than looking for a space anywhere in either string:

                check_choices = test_choice
            else:
                check_choices = [test_target]
            if self.config.doc_to_choice is not None:
                for choice in check_choices:
                    choice_has_whitespace = True if choice[0].isspace() else False
                    delimiter_has_whitespace = (
                        True
                        if self.config.target_delimiter[-1].isspace()
                        else False
                    )

                    if delimiter_has_whitespace and choice_has_whitespace:

(The old checks were `" " in choice` and `" " in self.config.target_delimiter`, with no doc_to_choice guard.)

@@ -1080,6 +1080,9 @@ multiple-target golds are now converted to a list instead of being stringified:

                # it assumes that doc_to_target returns a number.
                choices = self.doc_to_choice(doc)
                gold = choices[gold]
            # we expect multiple_targets to be a list.
            elif self.multiple_target:
                gold = list(gold)
            else:
                gold = str(gold)

@@ -1090,6 +1093,10 @@ and the per-option scoring loop now tolerates documents whose "multiple" target is actually a single string:

                # return true if any are true
                # TODO: this may break for multipLe_target, non zero-or-1 metrics
                scores = []
                if not isinstance(gold, list):
                    # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
                    # print(gold)
                    gold = [gold]
                for gold_option in gold:
                    try:
                        result_score = self._metric_fn_list[metric](
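The tightened check is aimed at prompts that would render with a doubled space: a `target_delimiter` that ends in whitespace combined with answer choices that begin with whitespace. A small self-contained illustration (the delimiter and choices are invented):

```python
def has_double_space(target_delimiter: str, choices: list) -> bool:
    # Mirrors the new check: trailing whitespace on the delimiter + leading whitespace on a choice.
    delimiter_has_whitespace = target_delimiter[-1].isspace()
    return any(delimiter_has_whitespace and choice[0].isspace() for choice in choices)


print(has_double_space(" ", [" yes", " no"]))  # True  -> the prompt would contain "...?  yes"
print(has_double_space(" ", ["yes", "no"]))    # False -> exactly one space between context and choice
```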
lm_eval/benchmarks/__init__.py  (deleted, 100644 → 0, -63)

The module is removed entirely; its benchmark-registration logic reappears in lm_eval/tasks/__init__.py as register_configurable_group (see below). The deleted file read:

    import os
    import yaml

    from lm_eval import utils
    from lm_eval.tasks import register_configurable_task, check_prompt_config
    from lm_eval.logger import eval_logger
    from lm_eval.api.registry import (
        TASK_REGISTRY,
        GROUP_REGISTRY,
        ALL_TASKS,
    )


    def include_benchmarks(task_dir: str) -> None:
        for root, subdirs, file_list in os.walk(task_dir):
            if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
                for f in file_list:
                    if f.endswith(".yaml"):
                        try:
                            benchmark_path = os.path.join(root, f)
                            with open(benchmark_path, "rb") as file:
                                yaml_config = yaml.full_load(file)

                            assert "group" in yaml_config
                            group = yaml_config["group"]
                            all_task_list = yaml_config["task"]
                            config_list = [task for task in all_task_list if type(task) != str]
                            task_list = [task for task in all_task_list if type(task) == str]

                            for task_config in config_list:
                                var_configs = check_prompt_config(
                                    {
                                        **task_config,
                                        **{"group": group},
                                    }
                                )
                                for config in var_configs:
                                    register_configurable_task(config)

                            task_names = utils.pattern_match(task_list, ALL_TASKS)
                            for task in task_names:
                                if task in TASK_REGISTRY:
                                    if group in GROUP_REGISTRY:
                                        GROUP_REGISTRY[group].append(task)
                                    else:
                                        GROUP_REGISTRY[group] = [task]
                                        ALL_TASKS.add(group)
                        except Exception as error:
                            eval_logger.warning(
                                "Failed to load benchmark in\n"
                                f"  {benchmark_path}\n"
                                "  Benchmark will not be added to registry\n"
                                f"  Error: {error}"
                            )


    task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
    include_benchmarks(task_dir)
lm_eval/decontamination/janitor.py  (+9 -9)

All nine changed lines are type annotations: the PEP 585 builtin generics (tuple[...], list[...]) are swapped for their typing equivalents, which also work on Python versions before 3.9, and the import is extended to match. The surrounding code (the janitor_util.cpp build instructions, the nltk-derived n-gram helpers, and the docstrings) is unchanged context. The affected import and signatures after the change (hunks at lines 3, 21, 70, 157, 168, 197 and 215):

    from typing import Iterator, Sequence, TypeVar, List, Tuple


    def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]: ...


    def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]: ...


    def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]: ...


    class Janitor:
        def clean(self, dirty_string: str) -> List[str]: ...

        def _split_chunks(self, dirty_string: str, dirty_parts: Sequence[Tuple]) -> List[str]: ...

        def clean_cpp(self, dirty_string: str) -> List[str]: ...

        def clean_python(self, dirty_string: str) -> List[str]: ...
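For context on why the annotations change: subscripting the builtins (`list[str]`, `tuple[int, int]`) in a signature is only legal on Python 3.9+, while the `typing` aliases work on every supported version. A minimal demonstration, not taken from the repo:

```python
import sys
from typing import List


def clean_sketch(dirty_string: str) -> List[str]:
    # Writing `-> list[str]` here would raise "TypeError: 'type' object is not subscriptable"
    # at function-definition time on Python 3.8; the typing alias is accepted everywhere.
    return dirty_string.split()


print(sys.version_info[:2], clean_sketch("a b c"))
```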
lm_eval/evaluator.py  (+104 -37)

@@ -118,6 +118,8 @@ def simple_evaluate(: group entries whose task object is None are now skipped:

            task_obj = task_dict[task_name]
            if type(task_obj) == tuple:
                group, task_obj = task_obj
                if task_obj is None:
                    continue

            config = task_obj._config
            if num_fewshot is not None:

@@ -207,23 +209,30 @@ def evaluate(: the group bookkeeping is reworked. The old `aggregate` dict (task scores keyed by group) and the flat `task_groups` task-to-group lookup are replaced by aggregation and ordering structures, and the setup loop now records the group hierarchy and skips empty (None) group placeholders:

        samples = collections.defaultdict(list)
        # tracks all Instances/requests a model must generate output on.
        requests = collections.defaultdict(list)
        # Aggregated task scores presented with groups
        results_agg = collections.defaultdict(dict)
        # Aggregated groups scores only
        groups_agg = collections.defaultdict(dict)
        # stores the amount to pad out reqs per req. type so that
        # number of fwd passes per distributed rank is equal
        padding_requests = collections.defaultdict(int)
        # store the hierarchy to do proper ordering
        task_hierarchy = collections.defaultdict(list)
        # store the ordering of tasks and groups
        task_order = collections.defaultdict(int)
        # store the aggregation for aggregating across tasks in the same group
        sample_agg_fn = collections.defaultdict(dict)

        # get lists of each type of request
        for task_name, task in task_dict.items():
            if type(task) == tuple:
                group_name, task = task
                task_hierarchy[group_name].append(task_name)
            else:
                task_hierarchy[task_name] = []

            if task is None:
                continue

            versions[task_name] = task.VERSION
            configs[task_name] = dict(task.dump_config())

(The replaced version declared "# Stores task scores based on task grouping." with aggregate = collections.defaultdict(dict), "# tracks if a task was chosen via user selecting a group containing it" with task_groups = collections.defaultdict(dict), and a "# Stores group related keys and values for group-aggregation" comment, and filled them in the loop with task_groups[task_name] = group and aggregate[task_name] = {}.)

@@ -299,6 +308,8 @@ and @@ -308,6 +319,8 @@: the filter-application and metric-collection loops gain the same None guard:

        for task_name, task in task_dict.items():
            if type(task) == tuple:
                group, task = task
                if task is None:
                    continue

            task.apply_filters()

        ### Collect values of metrics on all datapoints ###

        for task_name, task in task_dict.items():
            if type(task) == tuple:
                group, task = task
                if task is None:
                    continue

            # TODO: make it possible to use a different metric per filter
            # iterate over different filters used
            for key in task.instances[0].filtered_resps.keys():

@@ -468,27 +481,62 @@ on rank 0, the aggregation step first derives a task ordering and a task-to-group map from task_hierarchy (so that nested groups aggregate correctly), then records both the per-task score and the raw per-sample values for every group containing the task:

        if lm.rank == 0:
            ### Get task ordering for correct sample-wide aggregation
            group_to_task = {}
            for group in task_hierarchy.keys():
                if group not in task_order:
                    task_order[group] = 0

                if len(task_hierarchy[group]) > 0:
                    group_to_task[group] = task_hierarchy[group].copy()

                for task in task_hierarchy[group]:
                    if task in task_order:
                        task_order[task] += 1
                    else:
                        task_order[task] = 1 + task_order[group]

                    if task in task_hierarchy:
                        group_to_task[group].remove(task)
                        group_to_task[group].extend(task_hierarchy[task])

            task_to_group = {}
            for group in group_to_task:
                for task in group_to_task[group]:
                    if task in task_to_group:
                        task_to_group[task].append(group)
                    else:
                        task_to_group[task] = [group]

            ### Aggregate results over all datapoints ###
            # aggregate results ; run bootstrap CIs
            for (task_name, key, metric), items in vals.items():
                task = task_dict[task_name]
                metric_key = metric + "," + key

                if type(task) == tuple:
                    group_name, task = task
                else:
                    group_name = None

                # Need to put back in results
                # pythia | acc
                #        | perplexity
                #        | word_perplexity
                #        | byte_perplexity
                #        | bits_per_byte
                agg_fn = task.aggregation()[metric]
                task_score = agg_fn(items)

                if group_name is not None:
                    sample_metric_key = metric + "(sample agg)," + key
                    for grouping in task_to_group[task_name]:
                        if metric_key in results[grouping]:
                            results[grouping][metric_key].append(task_score)
                        else:
                            results[grouping][metric_key] = [task_score]

                        if sample_metric_key in results[grouping]:
                            results[grouping][sample_metric_key] += items
                        else:
                            results[grouping][sample_metric_key] = items.copy()
                            sample_agg_fn[grouping][sample_metric_key] = agg_fn

                results[task_name][metric_key] = task_score

                # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
                # so we run them less iterations. still looking for a cleaner way to do this

(Previously this loop computed task_score = task.aggregation()[metric](items), stored it under results[task_name][metric + "," + key], and appended it to aggregate[group_name][metric] for the single group found via task_groups[task_name].)

@@ -503,19 +551,38 @@ after the bootstrap-stderr block, list-valued entries in results are collapsed: sample-aggregated metrics through the stored sample_agg_fn, plain group metrics via np.average. The final tables are then assembled with group rows indented by their depth, and the output dictionary exposes a "groups" key instead of the old "aggregate" key:

                    if stderr is not None:
                        results[task_name][metric + "_stderr" + "," + key] = stderr(items)

            if bool(results):
                for task_or_group in results.keys():
                    for metric in results[task_or_group].keys():
                        if type(results[task_or_group][metric]) == list:
                            if "(sample agg)" in metric:
                                results[task_or_group][metric] = sample_agg_fn[task_or_group][
                                    metric
                                ](results[task_or_group][metric])
                            else:
                                results[task_or_group][metric] = np.average(
                                    results[task_or_group][metric]
                                )
                            versions[task_or_group] = "N/A"

            for task_name, task in task_dict.items():
                if type(task) == tuple:
                    group_name, task = task
                    order = task_order[group_name]
                    tabbed_name = "-" * order + group_name
                    results_agg[tabbed_name] = results[group_name]
                    versions[tabbed_name] = versions[group_name]
                    if order == 0:
                        groups_agg[group_name] = results[group_name]

                order = task_order[task_name]
                tabbed_name = "-" * order + task_name
                results_agg[tabbed_name] = results[task_name]
                versions[tabbed_name] = versions[task_name]

            results_dict = {
                "results": dict(results_agg.items()),
                **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
                "configs": dict(sorted(configs.items())),
                "versions": dict(sorted(versions.items())),
            }

(The replaced block averaged aggregate[group][metric] with np.average, set versions[group] = "N/A", and built results_dict from dict(sorted(results.items())) plus an "aggregate" key when aggregate was non-empty.)
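To make the new ordering and aggregation bookkeeping concrete, here is a small standalone sketch of how `task_hierarchy` and `task_order` turn into the dash-indented result rows and a per-group score. The group, tasks, and scores are invented, and the real evaluator additionally supports sample-level aggregation via `sample_agg_fn`, which is omitted here:

```python
import collections
import statistics

# Invented example: one group ("pythia") containing two tasks, plus a standalone task.
task_hierarchy = {"pythia": ["sciq", "piqa"], "wikitext": []}
task_scores = {"sciq": 0.93, "piqa": 0.76, "wikitext": 0.42}

# Depth of each entry: groups sit at 0, their member tasks one level deeper.
task_order = collections.defaultdict(int)
for group, members in task_hierarchy.items():
    task_order.setdefault(group, 0)
    for member in members:
        task_order[member] = task_order[group] + 1

# Group score = mean of member-task scores (the "(sample agg)" variant would
# instead re-aggregate the raw per-sample values).
results = dict(task_scores)
for group, members in task_hierarchy.items():
    if members:
        results[group] = statistics.mean(task_scores[m] for m in members)

for name in ["pythia", "sciq", "piqa", "wikitext"]:
    print("-" * task_order[name] + name, round(results[name], 4))
# pythia 0.845
# -sciq 0.93
# -piqa 0.76
# wikitext 0.42
```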
lm_eval/models/huggingface.py  (+6 -3)

@@ -101,17 +101,20 @@ class HFLM(LM): the set of accepted device strings now lists the MPS variants explicitly, and the float32 notice is only emitted on stable (non-nightly) PyTorch builds, with a pointer to the nightly install command:

            if not (parallelize or accelerator.num_processes > 1):
                # use user-passed device
                device_list = set(
                    ["cuda", "cpu"]
                    + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
                    + ["mps", "mps:0"]
                )
                if device:
                    if device not in device_list:
                        device = int(device)
                    self._device = torch.device(device)
                    eval_logger.info(f"Using device '{device}'")
                    if device in ("mps", "mps:0") and "dev" not in torch.__version__:
                        eval_logger.info(
                            "MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of "
                            "PyTorch: pip3 install --pre torch torchvision torchaudio --index-url "
                            "https://download.pytorch.org/whl/nightly/cpu"
                        )
                else:
                    eval_logger.info("Device not specified")

(Previously device_list was set(["cuda", "cpu", "mps"] + [f"cuda:{i}" ...]), the check was `if device == "mps":`, and the message read "MPS is still in beta and only supports float32; setting dtype to float32.")
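A minimal sketch of the dtype behaviour this change encodes; it assumes nothing about `HFLM` beyond what the diff shows and uses only standard PyTorch calls:

```python
import torch


def maybe_force_float32(device: str, requested: torch.dtype) -> torch.dtype:
    # Mirrors the guard in the diff: on stable (non-"dev") PyTorch builds,
    # MPS only supports float32, so any other request is downgraded.
    if device in ("mps", "mps:0") and "dev" not in torch.__version__:
        return torch.float32
    return requested


device = "mps" if torch.backends.mps.is_available() else "cpu"
print(device, maybe_force_float32(device, torch.float16))
```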
lm_eval/prompts/__init__.py  (+11 -2)

Two new imports at the top of the module:

    import ast
    from typing import Dict

    from lm_eval import utils
    from lm_eval.logger import eval_logger

@@ -5,7 +8,7 @@ from lm_eval.logger import eval_logger — the registry annotation switches to the typing alias:

    # Stores prompts in a dictionary indexed by 2 levels:
    # prompt category name, and prompt name.
    # This allows us to access prompts
    PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
        "qa-basic": {
            "question-newline-answer": "Question: {{question}}\nAnswer:",
            "q-newline-a": "Q: {{question}}\nA:",

@@ -63,6 +66,12 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwargs): the prompt-name part of use_prompt is now unpacked into a list so that several names (or wildcards) can eventually be matched:

        else:
            prompts = DatasetTemplates(dataset_name=dataset_name, subset_name=subset_name)

        category_name, *prompt_name = use_prompt.split(":")
        # TODO allow to multiple prompt naming
        # if len(prompt_name) > 1:
        #     prompt_list = []
        #     for prompt in prompt_name:
        #         prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
        # else:
        prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)

        return [":".join([category_name, prompt]) for prompt in prompt_list]

(Previously the split was `category_name, prompt_name = use_prompt.split(":")`.)
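The `category_name, *prompt_name = use_prompt.split(":")` change means the prompt part now arrives as a list, which suits wildcard expansion. A hedged sketch of that flow, with `fnmatch` standing in for `utils.pattern_match` and invented template names:

```python
from fnmatch import fnmatch

all_template_names = ["GPT-3 Style", "after_reading", "based_on_that"]  # invented examples

use_prompt = "promptsource:*"
category_name, *prompt_names = use_prompt.split(":")

# Stand-in for utils.pattern_match: expand each (possibly wildcarded) name
# against the templates that actually exist.
prompt_list = sorted(
    {name for pattern in prompt_names for name in all_template_names if fnmatch(name, pattern)}
)

print([":".join([category_name, prompt]) for prompt in prompt_list])
# ['promptsource:GPT-3 Style', 'promptsource:after_reading', 'promptsource:based_on_that']
```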
lm_eval/tasks/README.md  (+2 -2)

Two entries in the task-porting checklist ("Boxes should be checked iff tasks are implemented in the refactor and tested") flip from unchecked to checked: QASPER and MuTual.

@@ -16,7 +16,7 @@

    - [x] MCTACO
    - [x] Pubmed QA
    - [x] SciQ
    - [x] QASPER
    - [x] QA4MRE
    - [x] TriviaQA
    - [x] AI2 ARC

@@ -36,7 +36,7 @@

    - [x] TruthfulQA (mc1)
    - [x] TruthfulQA (mc2)
    - [x] TruthfulQA (gen)
    - [x] MuTual
    - [ ] Hendrycks Math (Hailey)
    - [x] Asdiv
    - [ ] GSM8k
lm_eval/tasks/__init__.py  (+61 -13)

The typing import gains Dict:

    import os
    import yaml
    from typing import List, Union, Dict

    from lm_eval import utils
    from lm_eval import prompts

@@ -15,7 +15,7 @@ from lm_eval.api.registry import ( — register_configurable_task is annotated with the typing alias:

    def register_configurable_task(config: Dict[str, str]) -> int:
        SubClass = type(
            config["task"] + "ConfigurableTask",
            (ConfigurableTask,),

@@ -38,7 +38,35 @@ a new register_configurable_group takes over the benchmark-registration logic that used to live in lm_eval/benchmarks/__init__.py, and check_prompt_config is re-annotated:

        return 0


    def register_configurable_group(config: Dict[str, str]) -> int:
        group = config["group"]
        all_task_list = config["task"]
        config_list = [task for task in all_task_list if type(task) != str]
        task_list = [task for task in all_task_list if type(task) == str]

        for task_config in config_list:
            var_configs = check_prompt_config(
                {
                    **task_config,
                    **{"group": group},
                }
            )
            for config in var_configs:
                register_configurable_task(config)

        task_names = utils.pattern_match(task_list, ALL_TASKS)
        for task in task_names:
            if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
                if group in GROUP_REGISTRY:
                    GROUP_REGISTRY[group].append(task)
                else:
                    GROUP_REGISTRY[group] = [task]
                    ALL_TASKS.add(group)

        return 0


    def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
        all_configs = []
        if "use_prompt" in config:
            prompt_list = prompts.load_prompt_list(

@@ -69,14 +97,14 @@ get_task_name_from_config is likewise re-annotated, and include_task_folder gains a register_task flag:

        return all_configs


    def get_task_name_from_config(task_config: Dict[str, str]) -> str:
        if "dataset_name" in task_config:
            return "{dataset_path}_{dataset_name}".format(**task_config)
        else:
            return "{dataset_path}".format(**task_config)


    def include_task_folder(task_dir: str, register_task=True) -> None:
        """
        Calling this function
        """

@@ -87,9 +115,16 @@ inside include_task_folder, the first pass registers individual tasks; the second pass only registers group (benchmark) configs, identified by a list-valued task field:

                    yaml_path = os.path.join(root, f)
                    try:
                        config = utils.load_yaml_config(yaml_path)

                        if register_task:
                            all_configs = check_prompt_config(config)
                            for config in all_configs:
                                register_configurable_task(config)
                        else:
                            # If a `task` in config is a list,
                            # that means it's a benchmark
                            if type(config["task"]) == list:
                                register_configurable_group(config)

                    except Exception as error:
                        eval_logger.warning(

@@ -102,6 +137,8 @@ at import time the tasks folder is now scanned twice:

    task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
    include_task_folder(task_dir)
    # Register Benchmarks after all tasks have been added
    include_task_folder(task_dir, register_task=False)


    def get_task(task_name, config):

@@ -128,7 +165,7 @@ and @@ -136,6 +173,9 @@ get_task_dict is re-annotated and now accepts a single element as well as a list:

    # TODO: pass num_fewshot and other cmdline overrides in a better way
    def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
        config = {**kwargs}

        task_name_from_config_dict = {}
        task_name_from_object_dict = {}

        if type(task_name_list) != list:
            task_name_list = [task_name_list]

        for task_element in task_name_list:
            if isinstance(task_element, str):

@@ -143,12 +183,20 @@ when a group name is requested, member tasks are now resolved recursively through get_task_dict (so nested groups work) instead of being instantiated directly with get_task:

                    group_name = task_element
                    for task_name in GROUP_REGISTRY[task_element]:
                        if task_name not in task_name_from_registry_dict:
                            task_obj = get_task_dict(task_name)
                            if task_name in task_obj.keys():
                                task_dict = {
                                    task_name: (group_name, task_obj[task_name]),
                                }
                            else:
                                task_dict = {
                                    task_name: (group_name, None),
                                    **task_obj,
                                }

                            task_name_from_registry_dict = {
                                **task_name_from_registry_dict,
                                **task_dict,
                            }
                else:
                    task_name = task_element

(Previously this branch extended task_name_from_registry_dict directly with task_name: (group_name, get_task(task_name=task_name, config=config)).)
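To see what the two-pass scan in `lm_eval/tasks/__init__.py` achieves, here is a toy version: the first pass registers ordinary task configs, the second pass (`register_task=False`) looks only at configs whose `task` field is a list and records them as groups. The registry names mirror those in the diff; the configs themselves are invented:

```python
# Toy registries mirroring TASK_REGISTRY / GROUP_REGISTRY / ALL_TASKS from the diff.
TASK_REGISTRY, GROUP_REGISTRY, ALL_TASKS = {}, {}, set()

configs = [
    {"task": "sciq"},                               # plain task config (invented)
    {"task": "piqa"},                               # plain task config (invented)
    {"group": "pythia", "task": ["sciq", "piqa"]},  # benchmark-style config: `task` is a list
]

# Pass 1: register ordinary tasks.
for cfg in configs:
    if not isinstance(cfg["task"], list):
        TASK_REGISTRY[cfg["task"]] = object()
        ALL_TASKS.add(cfg["task"])

# Pass 2: with register_task=False, only list-valued `task` configs are treated as groups.
for cfg in configs:
    if isinstance(cfg["task"], list):
        group = cfg["group"]
        for task in cfg["task"]:
            if task in TASK_REGISTRY or task in GROUP_REGISTRY:
                GROUP_REGISTRY.setdefault(group, []).append(task)
                ALL_TASKS.add(group)

print(GROUP_REGISTRY)     # {'pythia': ['sciq', 'piqa']}
print(sorted(ALL_TASKS))  # ['piqa', 'pythia', 'sciq']
```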
lm_eval/benchmarks/pythia.yaml → lm_eval/tasks/benchmarks/pythia.yaml  (+4 -4)

The benchmark definition moves under lm_eval/tasks/; the task list is re-sorted and `arc` becomes `ai2_arc`:

    group: pythia
    task:
      - lambada_openai
      - logiqa
      - piqa
      - sciq
      - wikitext
      - winogrande
      - wsc
      - ai2_arc
      - blimp
      - hendrycksTest*

(Previous order: lambada_openai, wikitext, piqa, sciq, wsc, winogrande, arc, logiqa, blimp, hendrycksTest*.)
lm_eval/benchmarks/t0_eval.yaml → lm_eval/tasks/benchmarks/t0_eval.yaml  (+0 -0)

File moved; no content changes.
lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_bn.yaml  (new file, +8)

    # Generated by utils.py
    dataset_name: bn
    doc_to_target: '{% if answer is not none %}{{answer[16+1]}}{% else %}{{answer_number|string}}{% endif %}'
    doc_to_text: '{% if answer is not none %}{{question+"\nধাপে ধাপে উত্তর:"}}{% else %}{{"প্রশ্ন: "+question+"\nধাপে ধাপে উত্তর:"}}{% endif %}'
    include: cot_yaml
    task: mgsm_bn_direct

lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_de.yaml  (new file, +8)

    # Generated by utils.py
    dataset_name: de
    doc_to_target: '{% if answer is not none %}{{answer[28+1]}}{% else %}{{answer_number|string}}{% endif %}'
    doc_to_text: '{% if answer is not none %}{{question+"\nSchritt-für-Schritt-Antwort:"}}{% else %}{{"Frage: "+question+"\nSchritt-für-Schritt-Antwort:"}}{% endif %}'
    include: cot_yaml
    task: mgsm_de_direct

lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_en.yaml  (new file, +8)

    # Generated by utils.py
    dataset_name: en
    doc_to_target: '{% if answer is not none %}{{answer[20+1]}}{% else %}{{answer_number|string}}{% endif %}'
    doc_to_text: '{% if answer is not none %}{{question+"\nStep-by-Step Answer:"}}{% else %}{{"Question: "+question+"\nStep-by-Step Answer:"}}{% endif %}'
    include: cot_yaml
    task: mgsm_en_direct

lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_es.yaml  (new file, +8)

    # Generated by utils.py
    dataset_name: es
    doc_to_target: '{% if answer is not none %}{{answer[22+1]}}{% else %}{{answer_number|string}}{% endif %}'
    doc_to_text: '{% if answer is not none %}{{question+"\nRespuesta paso a paso:"}}{% else %}{{"Pregunta: "+question+"\nRespuesta paso a paso:"}}{% endif %}'
    include: cot_yaml
    task: mgsm_es_direct

lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_fr.yaml  (new file, +8)

    # Generated by utils.py
    dataset_name: fr
    doc_to_target: '{% if answer is not none %}{{answer[25+1]}}{% else %}{{answer_number|string}}{% endif %}'
    doc_to_text: '{% if answer is not none %}{{question+"\nRéponse étape par étape :"}}{% else %}{{"Question : "+question+"\nRéponse étape par étape :"}}{% endif %}'
    include: cot_yaml
    task: mgsm_fr_direct

lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_ja.yaml  (new file, +8)

    # Generated by utils.py
    dataset_name: ja
    doc_to_target: '{% if answer is not none %}{{answer[10+1]}}{% else %}{{answer_number|string}}{% endif %}'
    doc_to_text: '{% if answer is not none %}{{question+"\nステップごとの答え:"}}{% else %}{{"問題: "+question+"\nステップごとの答え:"}}{% endif %}'
    include: cot_yaml
    task: mgsm_ja_direct

lm_eval/tasks/mgsm/native_cot/mgsm_cot_native_ru.yaml  (new file, +8)

    # Generated by utils.py
    dataset_name: ru
    doc_to_target: '{% if answer is not none %}{{answer[17+1]}}{% else %}{{answer_number|string}}{% endif %}'
    doc_to_text: '{% if answer is not none %}{{question+"\nПошаговоерешение:"}}{% else %}{{"Задача: "+question+"\nПошаговоерешение:"}}{% endif %}'
    include: cot_yaml
    task: mgsm_ru_direct
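Each of these files drives the prompt through a small Jinja template: few-shot documents carry an `answer` and are rendered bare, while the target document (answer of `None`) gets the question prefix. A hedged sketch that renders the English template by hand (the example documents are invented; `jinja2` is assumed to be installed):

```python
from jinja2 import Template

# Template text copied from mgsm_cot_native_en.yaml; the "\n" is interpreted by Jinja's string literal.
doc_to_text = (
    '{% if answer is not none %}{{question+"\\nStep-by-Step Answer:"}}'
    '{% else %}{{"Question: "+question+"\\nStep-by-Step Answer:"}}{% endif %}'
)
template = Template(doc_to_text)

fewshot_doc = {"question": "Roger has 5 balls and buys 2 more. How many balls?",
               "answer": "Step-by-Step Answer: 5 + 2 = 7. The answer is 7."}
target_doc = {"question": "A farm has 4 hens and 6 cows. How many animals?", "answer": None}

print(template.render(**fewshot_doc))  # question + "\nStep-by-Step Answer:"
print("---")
print(template.render(**target_doc))   # "Question: " + question + "\nStep-by-Step Answer:"
```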