"docs/git@developer.sourcefind.cn:norm/vllm.git" did not exist on "c9fadda54353f1b57c3dae9b7cbebda6f0767f8e"
Unverified Commit f88ffeee authored by Hailey Schoelkopf, committed by GitHub

Merge branch 'big-refactor' into add-fewshot-config

parents 2d5d94da 0f6cd358
name: Tasks Modified

on:
  push:
    branches:
      - 'big-refactor*'
  pull_request:
    branches:
      - 'big-refactor*'
  workflow_dispatch:
# comment/edit out the above to stop/change the triggers
jobs:
  changed_files:
    runs-on: ubuntu-latest  # windows-latest || macos-latest
    timeout-minutes: 120
    name: Scan for changed tasks
    steps:
      - name: checkout
        uses: actions/checkout@v3
        with:
          fetch-depth: 2  # OR "2" -> To retrieve the preceding commit.

      # Uses the tj-actions/changed-files@v37 action to check for changes.
      # Outputs provided here: https://github.com/tj-actions/changed-files#outputs
      # The `files_yaml` input optionally takes a yaml string to specify filters,
      # and prepends the filter name to the standard output names.
      - name: Check task folders
        id: changed-tasks
        uses: tj-actions/changed-files@v37.1.2
        with:
          # tasks checks the tasks folder and api checks the api folder for changes
          files_yaml: |
            tasks:
              - lm_eval/tasks/**
            api:
              - lm_eval/api/**
          write_output_files: true

      # The next step is optional; the files are written to the workspace by default (above).
      # so it's just for debugging
      - name: Run Tests
        if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
        run: |
          echo .github/outputs/tasks_all_changed_and_modified_files.txt >> 'GITHUB_ENV'
          echo "One or more test file(s) has changed."
          echo "List of all the files that have changed: ${{ steps.changed-tasks.outputs.tasks_all_modified_files }}"

      - name: Set up Python 3.9
        if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
        uses: actions/setup-python@v4
        with:
          python-version: 3.9
          cache: 'pip'
          cache-dependency-path: setup.py
      - name: Install dependencies
        if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
        run: |
          python -m pip install --upgrade pip
          pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
          # Install optional git dependencies
          # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
          # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
      - name: Test with pytest
        # if new tasks are added, run tests on them
        if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
        run: python -m pytest tests/test_tasks.py -s -vv
      # if api is modified, run tests on it
      - name: Test more tasks with pytest
        env:
          API: true
        if: steps.changed-tasks.outputs.api_any_modified == 'true'
        run: python -m pytest tests/test_tasks.py -s -vv
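For local debugging of the `write_output_files: true` option, a sketch along these lines (a hypothetical helper, assuming the action's default whitespace-separated output format) reads the file the workflow echoes in its "Run Tests" step and lists the task files that would trigger tests:

# Illustrative sketch only, not part of the workflow: inspect the output file that
# tj-actions/changed-files writes to the workspace when `write_output_files: true`.
from pathlib import Path

# Path taken from the workflow's debug step; the separator is assumed to be whitespace.
OUTPUT_FILE = Path(".github/outputs/tasks_all_changed_and_modified_files.txt")


def changed_task_files() -> list:
    if not OUTPUT_FILE.exists():
        return []
    return OUTPUT_FILE.read_text().split()


if __name__ == "__main__":
    changed = changed_task_files()
    if changed:
        print("One or more task file(s) changed; task tests would run for:")
        for path in changed:
            print(f"  {path}")
    else:
        print("No task changes detected; task tests would be skipped.")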
@@ -40,39 +40,38 @@ jobs:
          flake8 . --count --select=F,E9,E71,E72,E501,E112,E113,W6 --extend-ignore=F541 --show-source --statistics --exit-zero
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
      # mypy turned off for now
      # - name: Lint with mypy
      #   run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
  # Job 2
  testcpu:
    name: CPU Tests
    runs-on: ubuntu-latest
    strategy:
      matrix:
-       python-version: [ "3.9", "3.10", "3.11" ]
+       python-version: [ "3.8", "3.9", "3.10", "3.11" ]
    timeout-minutes: 30
    steps:
      - name: Checkout Code
        uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          cache: pip
          cache-dependency-path: setup.py
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
          # Install optional git dependencies
          # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
          # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
      - name: Test with pytest
        run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/tests_master --ignore=tests/extra
      - name: Archive artifacts
        uses: actions/upload-artifact@v3
        with:
          name: output_results
          path: |
            test_logs/*
@@ -40,7 +40,7 @@ repos:
      - id: codespell
        exclude: >
            (?x)^(
-               .*\.json|ignore.txt|.*yaml
+               .*\.json|ignore.txt|lm_eval/tasks/.*|.*yaml
            )$
        args: [--check-filenames, --check-hidden, --ignore-words=ignore.txt]
  - repo: https://github.com/pre-commit/mirrors-mypy
@@ -46,14 +46,14 @@ class ContextSampler:
                )
                + self.target_delimiter
                + (
-                   self.doc_to_target(doc)[0]
+                   str(self.doc_to_target(doc)[0])
                    if type(self.doc_to_target(doc)) is list
                    else self.doc_to_target(doc)
                    if (
                        self.config.doc_to_choice is None
                        or type(self.doc_to_target(doc)) is str
                    )
-                   else self.doc_to_choice(doc)[self.doc_to_target(doc)]
+                   else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
                )
                for doc in selected_docs
            ]
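The two str(...) casts added above guard against non-string targets: doc_to_target can return a plain string, a list, or an integer index into doc_to_choice, and the few-shot context is assembled by string concatenation. A minimal standalone sketch of that branch logic (illustrative names, not the library code):

# Standalone sketch of the branch logic above; `resolve_target` is illustrative,
# not lm_eval code. The str() casts matter because the few-shot context is built
# by concatenating strings.
def resolve_target(target, choices=None):
    # target mirrors doc_to_target(doc): a string, a list, or an index into choices
    if isinstance(target, list):
        return str(target[0])
    if choices is None or isinstance(target, str):
        return target
    return str(choices[target])


assert resolve_target("Paris") == "Paris"
assert resolve_target(1, choices=["London", "Paris"]) == "Paris"
assert resolve_target([2, 3]) == "2"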
@@ -582,7 +582,7 @@ class ConfigurableTask(Task):
                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
                    metric_agg = get_default_aggregation(metric_name)
                    eval_logger.warning(
-                       f"metric {metric_name} is defined, but aggregation is not. "
+                       f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. "
                        f"using default "
                        f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
                    )
@@ -594,7 +594,7 @@ class ConfigurableTask(Task):
                    ]
                else:
                    eval_logger.warning(
-                       f"metric {metric_name} is defined, but higher_is_better is not. "
+                       f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. "
                        f"using default "
                        f"higher_is_better={is_higher_better(metric_name)}"
                    )
@@ -839,7 +839,10 @@ class ConfigurableTask(Task):
                and (target_string[0] == "[")
                and (target_string[-1] == "]")
            ):
-               return ast.literal_eval(target_string)
+               try:
+                   return ast.literal_eval(target_string)
+               except (SyntaxError, ValueError):
+                   return target_string
            else:
                return target_string
        elif type(doc_to_target) == list:
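The new try/except narrows the bracket heuristic: a target string that merely looks like a Python list no longer crashes ast.literal_eval but falls back to the raw string. A self-contained sketch of the guarded branch, assuming the same (SyntaxError, ValueError) handling:

# Self-contained sketch of the guarded branch above; `parse_target_string` is illustrative.
import ast


def parse_target_string(target_string: str):
    if target_string.startswith("[") and target_string.endswith("]"):
        try:
            return ast.literal_eval(target_string)
        except (SyntaxError, ValueError):
            return target_string
    return target_string


assert parse_target_string("[1, 2, 3]") == [1, 2, 3]
assert parse_target_string("[foo, bar]") == "[foo, bar]"  # names are not literals -> fallback
assert parse_target_string("plain text") == "plain text"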
import os
import yaml

from lm_eval import utils
from lm_eval.tasks import register_configurable_task, check_prompt_config
from lm_eval.logger import eval_logger
from lm_eval.api.registry import (
    TASK_REGISTRY,
    GROUP_REGISTRY,
    ALL_TASKS,
)


def include_benchmarks(task_dir: str) -> None:
    for root, subdirs, file_list in os.walk(task_dir):
        if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
            for f in file_list:
                if f.endswith(".yaml"):
                    try:
                        benchmark_path = os.path.join(root, f)
                        with open(benchmark_path, "rb") as file:
                            yaml_config = yaml.full_load(file)

                        assert "group" in yaml_config
                        group = yaml_config["group"]
                        all_task_list = yaml_config["task"]
                        config_list = [
                            task for task in all_task_list if type(task) != str
                        ]
                        task_list = [
                            task for task in all_task_list if type(task) == str
                        ]

                        for task_config in config_list:
                            var_configs = check_prompt_config(
                                {
                                    **task_config,
                                    **{"group": group},
                                }
                            )
                            for config in var_configs:
                                register_configurable_task(config)

                        task_names = utils.pattern_match(task_list, ALL_TASKS)
                        for task in task_names:
                            if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
                                if group in GROUP_REGISTRY:
                                    GROUP_REGISTRY[group].append(task)
                                else:
                                    GROUP_REGISTRY[group] = [task]
                                    ALL_TASKS.add(group)
                    except Exception as error:
                        eval_logger.warning(
                            "Failed to load benchmark in\n"
                            f" {benchmark_path}\n"
                            " Benchmark will not be added to registry\n"
                            f" Error: {error}"
                        )


task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_benchmarks(task_dir)
group: minerva_math
task:
- minerva_math_algebra
- minerva_math_counting_and_prob
- minerva_math_geometry
- minerva_math_intermediate_algebra
- minerva_math_num_theory
- minerva_math_prealgebra
- minerva_math_precalc
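Once a group YAML like this is registered, requesting the group name expands to its member tasks via GROUP_REGISTRY. A rough illustration with a stand-in registry (not lm_eval's real objects):

# Rough illustration with a stand-in dict; lm_eval keeps the real mapping in
# lm_eval.api.registry.GROUP_REGISTRY after the group YAML is loaded.
GROUP_REGISTRY = {
    "minerva_math": [
        "minerva_math_algebra",
        "minerva_math_counting_and_prob",
        "minerva_math_geometry",
        "minerva_math_intermediate_algebra",
        "minerva_math_num_theory",
        "minerva_math_prealgebra",
        "minerva_math_precalc",
    ]
}


def expand_task_names(requested):
    """Replace any group name with its member tasks; plain task names pass through."""
    expanded = []
    for name in requested:
        expanded.extend(GROUP_REGISTRY.get(name, [name]))
    return expanded


print(expand_task_names(["minerva_math", "sciq"]))
# ['minerva_math_algebra', ..., 'minerva_math_precalc', 'sciq']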
@@ -3,7 +3,7 @@ import string
import pickle
import traceback
from pprint import pprint
-from typing import Iterator, Sequence, TypeVar
+from typing import Iterator, Sequence, TypeVar, List, Tuple

# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
@@ -21,7 +21,7 @@ T = TypeVar("T")

# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
-def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[tuple[T, ...]]:
+def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
    history = []
    while n > 1:
        # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
@@ -70,14 +70,14 @@ def word_ngrams(s: str, n: int) -> Iterator[str]:

# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
-def split_indices(s: str) -> Iterator[tuple[str, tuple[int, int]]]:
+def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
    """Splits a string on whitespaces and records the indices of each in the original string.
    @:return generator((word, (start_idx, end_idx)), ...)
    """
    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))


-def word_ngrams_indices(s: str, n: int) -> Iterator[tuple[str, tuple[int, int]]]:
+def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
    """Splits a string into pairs of (ngram words, their start/end indices)"""
    tokens_with_indices = split_indices(s)
@@ -157,7 +157,7 @@ class Janitor:
        print("WARNING: Janitor running in python mode")
        return self.register_contaminant_python(dirt_string)

-   def clean(self, dirty_string: str) -> list[str]:
+   def clean(self, dirty_string: str) -> List[str]:
        """Clean a string (e.g. a training set) by removing all ngrams previously
        registered as contaminants. Returns a list of clean chunks, or empty if
        the string was too dirty"""
@@ -168,8 +168,8 @@ class Janitor:
        return self.clean_python(dirty_string)

    def _split_chunks(
-       self, dirty_string: str, dirty_parts: Sequence[tuple]
-   ) -> list[str]:
+       self, dirty_string: str, dirty_parts: Sequence[Tuple]
+   ) -> List[str]:
        clean_chunks = []
        splice_idx = 0
        end = -1
@@ -197,7 +197,7 @@ class Janitor:
            janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
        )

-   def clean_cpp(self, dirty_string: str) -> list[str]:
+   def clean_cpp(self, dirty_string: str) -> List[str]:
        contamination_indices = janitor_util.clean_ngram_with_indices(
            dirty_string, self.delete_chars, self.ngram_n
        )
@@ -215,7 +215,7 @@ class Janitor:
            word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
        )

-   def clean_python(self, dirty_string: str) -> list[str]:
+   def clean_python(self, dirty_string: str) -> List[str]:
        contamination_indices = (
            (None, *idx_pair)
            for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
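The tuple -> Tuple and list -> List changes above keep the janitor importable on Python 3.8, where builtin generics such as tuple[T, ...] cannot be subscripted in annotations. A small usage sketch of the helpers shown above (the import path is an assumption; adjust if the module lives elsewhere):

# Usage sketch; module path assumed to be lm_eval.decontamination.janitor.
from lm_eval.decontamination.janitor import form_ngrams, split_indices

words = iter(["the", "cat", "sat", "on", "the", "mat"])
print(list(form_ngrams(words, 3)))
# [('the', 'cat', 'sat'), ('cat', 'sat', 'on'), ('sat', 'on', 'the'), ('on', 'the', 'mat')]

print(list(split_indices("the cat")))
# [('the', (0, 2)), ('cat', (4, 6))]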
@@ -11,7 +11,6 @@ import numpy as np

import lm_eval.api
import lm_eval.tasks
-import lm_eval.benchmarks
import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.registry
@@ -508,7 +508,7 @@ class HFLM(LM):
            self.tokenizer, stop, 1, context.shape[0]
        )
        return self.model.generate(
-           context,
+           input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.eot_token_id,
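The change above passes the prompt tensor to generate by keyword instead of positionally. Outside of lm_eval, the same call pattern with plain transformers looks roughly like this (model name and lengths are placeholders):

# Standalone sketch using plain transformers, not lm_eval's HFLM wrapper.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

context = tokenizer("The capital of France is", return_tensors="pt").input_ids
output = model.generate(
    input_ids=context,                  # explicit keyword, as in the diff above
    max_length=context.shape[1] + 8,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))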
@@ -16,7 +16,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] MCTACO
- [x] Pubmed QA
- [x] SciQ
-- [ ] QASPER
+- [x] QASPER
- [x] QA4MRE
- [x] TriviaQA
- [x] AI2 ARC
@@ -36,7 +36,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] TruthfulQA (mc1)
- [x] TruthfulQA (mc2)
- [x] TruthfulQA (gen)
-- [ ] MuTual
+- [x] MuTual
- [ ] Hendrycks Math (Hailey)
- [x] Asdiv
- [ ] GSM8k
@@ -38,6 +38,34 @@ def register_configurable_task(config: Dict[str, str]) -> int:
    return 0


def register_configurable_group(config: Dict[str, str]) -> int:
    group = config["group"]
    all_task_list = config["task"]
    config_list = [task for task in all_task_list if type(task) != str]
    task_list = [task for task in all_task_list if type(task) == str]

    for task_config in config_list:
        var_configs = check_prompt_config(
            {
                **task_config,
                **{"group": group},
            }
        )
        for config in var_configs:
            register_configurable_task(config)

    task_names = utils.pattern_match(task_list, ALL_TASKS)
    for task in task_names:
        if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
            if group in GROUP_REGISTRY:
                GROUP_REGISTRY[group].append(task)
            else:
                GROUP_REGISTRY[group] = [task]
                ALL_TASKS.add(group)

    return 0


def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
    all_configs = []
    if "use_prompt" in config:
@@ -76,7 +104,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    return "{dataset_path}".format(**task_config)


-def include_task_folder(task_dir: str) -> None:
+def include_task_folder(task_dir: str, register_task=True) -> None:
    """
    Calling this function
    """
@@ -87,9 +115,16 @@ def include_task_folder(task_dir: str) -> None:
                yaml_path = os.path.join(root, f)
                try:
                    config = utils.load_yaml_config(yaml_path)
-                   all_configs = check_prompt_config(config)
-                   for config in all_configs:
-                       register_configurable_task(config)
+
+                   if register_task:
+                       all_configs = check_prompt_config(config)
+                       for config in all_configs:
+                           register_configurable_task(config)
+                   else:
+                       # If a `task` in config is a list,
+                       # that means it's a benchmark
+                       if type(config["task"]) == list:
+                           register_configurable_group(config)

                except Exception as error:
                    eval_logger.warning(
@@ -102,6 +137,8 @@ def include_task_folder(task_dir: str) -> None:

task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_task_folder(task_dir)
+# Register Benchmarks after all tasks have been added
+include_task_folder(task_dir, register_task=False)


def get_task(task_name, config):
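For reference, this is roughly the shape of config that reaches register_configurable_group when a YAML's task key is a list; the names below are made up for illustration. Plain strings are matched against already-registered tasks, dict entries are registered as new configurable tasks under the group, and the two-pass loading at the bottom of the module ensures individual tasks exist before groups are resolved:

# Illustrative only; the group name and the inline task below are invented.
example_group_config = {
    "group": "my_benchmark",
    "task": [
        "sciq",       # plain string: name of an already-registered task, matched via ALL_TASKS
        "arc_easy",   # plain string: another task name
        {             # dict: an inline task config, registered on the spot under the group
            "task": "my_benchmark_extra",
            "dataset_path": "some_org/some_dataset",
            "output_type": "multiple_choice",
        },
    ],
}

# Two-pass loading, as shown at the bottom of the module in the diff above:
#   include_task_folder(task_dir)                        # pass 1: individual task YAMLs
#   include_task_folder(task_dir, register_task=False)   # pass 2: group/benchmark YAMLs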
group: csatqa
dataset_path: EleutherAI/csatqa
test_split: test
output_type: multiple_choice
process_docs: !function utils.process_docs
doc_to_text: "{{question}}"
doc_to_choice: "{{choices}}"
doc_to_target: "{{gold}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
  - metric: acc_norm
    aggregation: mean
    higher_is_better: true
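The process_docs hook referenced via !function utils.process_docs is expected to return docs exposing the question, choices, and gold fields used by doc_to_text, doc_to_choice, and doc_to_target above. A hypothetical sketch only; the real csatqa utils.py and the raw dataset column names may differ:

# Hypothetical example of a process_docs hook; the raw column names are assumptions.
import datasets


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    def _reformat(doc):
        return {
            "question": doc["question"],      # assumed raw field
            "choices": doc["options"],        # assumed raw field holding the answer options
            "gold": int(doc["answer"]) - 1,   # assumed 1-indexed label -> 0-indexed choice
        }

    return dataset.map(_reformat)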
"""
Take in a YAML, and output all other splits with this YAML
"""
import os
import yaml
import argparse
from tqdm import tqdm
from lm_eval.logger import eval_logger
SUBSETS = ["WR", "GR", "RCS", "RCSS", "RCH", "LI"]
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--base_yaml_path", required=True)
parser.add_argument("--save_prefix_path", default="csatqa")
parser.add_argument("--task_prefix", default="")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
base_yaml_name = os.path.split(args.base_yaml_path)[-1]
with open(args.base_yaml_path) as f:
base_yaml = yaml.full_load(f)
for name in tqdm(SUBSETS):
yaml_dict = {
"include": base_yaml_name,
"task": f"csatqa_{args.task_prefix}_{name}"
if args.task_prefix != ""
else f"csatqa_{name.lower()}",
"dataset_name": name,
}
file_save_path = args.save_prefix_path + f"_{name.lower()}.yaml"
eval_logger.info(f"Saving yaml for subset {name} to {file_save_path}")
with open(file_save_path, "w") as yaml_file:
yaml.dump(
yaml_dict,
yaml_file,
width=float("inf"),
allow_unicode=True,
default_style='"',
)
"dataset_name": "GR"
"include": "_default_csatqa_yaml"
"task": "csatqa_gr"
"dataset_name": "LI"
"include": "_default_csatqa_yaml"
"task": "csatqa_li"
"dataset_name": "RCH"
"include": "_default_csatqa_yaml"
"task": "csatqa_rch"
"dataset_name": "RCS"
"include": "_default_csatqa_yaml"
"task": "csatqa_rcs"