Commit c4c20ff5 authored by lintangsutawika

pre-commit stuff

parent e56b950a
@@ -21,4 +21,4 @@ _CITATION = """
     year={2018},
     volume={abs/1803.05457}
 }
-"""
\ No newline at end of file
+"""
@@ -9,7 +9,7 @@ validation_split: validation
 test_split: test
 template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
 doc_to_text: "Question: {{question}}\nAnswer:"
-doc_to_target: "{{gold}}" # this will be cast to an int.
+doc_to_target: "{{gold}}" # this will be cast to an int.
 metric_list:
   - metric: acc
     aggregation: mean
@@ -19,4 +19,4 @@ metric_list:
     higher_is_better: true
   - metric: acc_mutual_info
     aggregation: mean
-    higher_is_better: true
\ No newline at end of file
+    higher_is_better: true
@@ -9,7 +9,7 @@ validation_split: validation
 test_split: test
 template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
 doc_to_text: "Question: {{question}}\nAnswer:"
-doc_to_target: "{{gold}}" # this will be cast to an int.
+doc_to_target: "{{gold}}" # this will be cast to an int.
 metric_list:
   - metric: acc
     aggregation: mean
@@ -19,4 +19,4 @@ metric_list:
     higher_is_better: true
   - metric: acc_mutual_info
     aggregation: mean
-    higher_is_better: true
\ No newline at end of file
+    higher_is_better: true
@@ -17,9 +17,10 @@ model's sample/generation function.
 Homepage: https://github.com/openai/grade-school-math
 """
 import re
+
 from lm_eval import utils
-from lm_eval.api.task import Task
 from lm_eval.api.metrics import mean
 from lm_eval.api.instance import Instance
+from lm_eval.api.task import Task
 from lm_eval.prompts import get_prompt
 
@@ -88,7 +89,13 @@ class GradeSchoolMath8K(Task):
         """
         # NOTE: The paper implements "verifiers" that assign a score to multiple
         # solutions and output the highest ranked solution.
-        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, ["\n"]), idx=0, **kwargs)
+        return Instance(
+            request_type=self.OUTPUT_TYPE,
+            doc=doc,
+            arguments=(ctx, ["\n"]),
+            idx=0,
+            **kwargs
+        )
         # completion = rf.greedy_until(ctx, ["\n"])
         # return completion
 
@@ -60,11 +60,18 @@ class LambadaBase(Task):
         return " " + doc["text"].rsplit(" ", 1)[1]

     def construct_requests(self, doc, ctx, **kwargs):
-        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, self.doc_to_target(doc)), **kwargs)
+        return Instance(
+            request_type=self.OUTPUT_TYPE,
+            doc=doc,
+            arguments=(ctx, self.doc_to_target(doc)),
+            **kwargs
+        )

     def process_results(self, doc, results):
         # TODO: this ^ is a hack. filters should make it so that we only have one response per request that we score
-        results = results[0]  # TODO: recheck this. currently a list of [(ll, is_greedy)] is passed in
+        results = results[
+            0
+        ]  # TODO: recheck this. currently a list of [(ll, is_greedy)] is passed in
         ll, is_greedy = results
         return {"ppl": ll, "acc": int(is_greedy)}
 
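For reference, both construct_requests reformats above build the same kind of request object. A minimal, self-contained sketch of its shape, using a toy dataclass in place of lm_eval.api.instance.Instance (field names follow the keyword arguments in the diff; the sample values are illustrative):

from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class Instance:
    # Toy stand-in for lm_eval.api.instance.Instance, mirroring the keyword
    # arguments used in the two construct_requests hunks above.
    request_type: str
    doc: dict
    arguments: Tuple[Any, ...]
    idx: int = 0


# Generation-style request (gsm8k): prompt context plus stop sequences.
gen = Instance("greedy_until", {"question": "..."}, ("Q: ...\nA:", ["\n"]))
# Loglikelihood-style request (lambada): context plus target continuation.
ll = Instance("loglikelihood", {"text": "..."}, ("The cat sat on the", " mat"))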
@@ -35,7 +35,7 @@ class PilePerplexityTask(PerplexityTask):
 
     def test_docs(self):
         for doc in self.dataset["train"].select(range(100)):
             yield doc
-
+
     def has_validation_docs(self):
         return False
@@ -140,4 +140,4 @@ class PileWikipedia(PilePerplexityTask):
 
 
 class PileYoutubeSubtitles(PilePerplexityTask):
-    DATASET_NAME = "pile_youtubesubtitles"
\ No newline at end of file
+    DATASET_NAME = "pile_youtubesubtitles"
@@ -37,4 +37,4 @@ metric_list:
     higher_is_better: false
   - metric: bits_per_byte
     aggregation: bits_per_byte
-    higher_is_better: false
\ No newline at end of file
+    higher_is_better: false
@@ -1,15 +1,16 @@
 import re

-def doc_to_text(x):
+
+def doc_to_text(x):
     def _mark_span(text, span_str, span_idx, mark):
-        pattern_tmpl = r'^((?:\S+\s){N})(W)'
-        pattern = re.sub('N', str(span_idx), pattern_tmpl)
-        pattern = re.sub('W', span_str, pattern)
-        return re.sub(pattern, r'\1{0} \2 {0}'.format(mark), text)
+        pattern_tmpl = r"^((?:\S+\s){N})(W)"
+        pattern = re.sub("N", str(span_idx), pattern_tmpl)
+        pattern = re.sub("W", span_str, pattern)
+        return re.sub(pattern, r"\1{0} \2 {0}".format(mark), text)

-    text = x['text']
-    text = _mark_span(text, x['span1_text'], x['span1_index'], '*')
+    text = x["text"]
+    text = _mark_span(text, x["span1_text"], x["span1_index"], "*")
     # Compensate for 2 added "words" added in previous step.
-    span2_index = x['span2_index'] + 2 * (x['span1_index'] < x['span2_index'])
-    text = _mark_span(text, x['span2_text'], span2_index, '#')
-    return text
\ No newline at end of file
+    span2_index = x["span2_index"] + 2 * (x["span1_index"] < x["span2_index"])
+    text = _mark_span(text, x["span2_text"], span2_index, "#")
+    return text
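A quick usage sketch of the reformatted helper above; the record mimics a SuperGLUE WSC example and the field values are illustrative:

import re


def _mark_span(text, span_str, span_idx, mark):
    # Wrap the span_idx-th whitespace-delimited token (expected to equal
    # span_str) in the given mark character, as in the file above.
    pattern_tmpl = r"^((?:\S+\s){N})(W)"
    pattern = re.sub("N", str(span_idx), pattern_tmpl)
    pattern = re.sub("W", span_str, pattern)
    return re.sub(pattern, r"\1{0} \2 {0}".format(mark), text)


# Illustrative WSC-style record (field values are made up for the demo).
x = {
    "text": "Mark told Pete many lies about himself.",
    "span1_text": "Mark",
    "span1_index": 0,
    "span2_text": "himself",
    "span2_index": 6,
}
text = _mark_span(x["text"], x["span1_text"], x["span1_index"], "*")
# Marking span1 inserted two extra "words", so shift span2's index.
span2_index = x["span2_index"] + 2 * (x["span1_index"] < x["span2_index"])
text = _mark_span(text, x["span2_text"], span2_index, "#")
print(text)  # -> * Mark * told Pete many lies about # himself #.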
@@ -60,6 +60,7 @@ def wikitext_detokenizer(string):

     return string

+
 @register_task("wikitext")
 class WikiText(PerplexityTask):
     VERSION = "2.0"
@@ -150,7 +150,6 @@ class Reorderer:

         return res

-
 def make_table(result_dict):
     """Generate table of results."""
     from pytablewriter import MarkdownTableWriter, LatexTableWriter
@@ -262,7 +261,7 @@ def import_function(loader, node):
     function_name = loader.construct_scalar(node)
     yaml_path = os.path.dirname(loader.name)

-    module_name, function_name = function_name.split('.')
+    module_name, function_name = function_name.split(".")
     module_path = os.path.join(yaml_path, "{}.py".format(module_name))

     spec = importlib.util.spec_from_file_location(module_name, module_path)
@@ -272,29 +271,30 @@ def import_function(loader, node):
 
     function = getattr(module, function_name)
     return function

 # Add the import_function constructor to the YAML loader
-yaml.add_constructor('!function', import_function)
+yaml.add_constructor("!function", import_function)
+

 def load_yaml_config(yaml_path):
-    with open(yaml_path, 'rb') as file:
+    with open(yaml_path, "rb") as file:
         yaml_config = yaml.full_load(file)
         yaml_dir = os.path.dirname(yaml_path)

-        if 'include' in yaml_config:
-            include_path = yaml_config['include']
-            del yaml_config['include']
+        if "include" in yaml_config:
+            include_path = yaml_config["include"]
+            del yaml_config["include"]

             if type(include_path) == str:
                 include_path = [include_path]

             # Load from the last one first
             include_path.reverse()
             final_yaml_config = {}
             for path in include_path:
                 # Assumes that path is a full path.
-                # If not found, assume the included yaml
+                # If not found, assume the included yaml
                 # is in the same dir as the original yaml
                 if not os.path.isfile(path):
                     path = os.path.join(yaml_dir, path)
 
@@ -302,9 +302,9 @@ def load_yaml_config(yaml_path):
                 try:
                     included_yaml_config = load_yaml_config(path)
                     final_yaml_config.update(included_yaml_config)
-                except:
+                except Exception as ex:
                     # If failed to load, ignore
-                    pass
+                    raise ex

             final_yaml_config.update(yaml_config)
             return final_yaml_config
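The include handling touched in these hunks is easiest to see end to end. A simplified, self-contained sketch of the merge semantics (the !function constructor is omitted and the file names are illustrative): included files are merged last-to-first, and keys in the including file win.

import os
import tempfile

import yaml


def load_yaml_config(yaml_path):
    # Simplified restatement of the helper above (no !function support).
    with open(yaml_path, "rb") as file:
        yaml_config = yaml.full_load(file)
    yaml_dir = os.path.dirname(yaml_path)
    if "include" in yaml_config:
        include_path = yaml_config["include"]
        del yaml_config["include"]
        if isinstance(include_path, str):
            include_path = [include_path]
        include_path.reverse()  # load from the last one first
        final_yaml_config = {}
        for path in include_path:
            if not os.path.isfile(path):
                path = os.path.join(yaml_dir, path)
            final_yaml_config.update(load_yaml_config(path))
        final_yaml_config.update(yaml_config)  # the including file wins
        return final_yaml_config
    return yaml_config


with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "base.yaml"), "w") as f:
        f.write("metric: acc\nsplit: test\n")
    with open(os.path.join(d, "task.yaml"), "w") as f:
        f.write("include: base.yaml\nsplit: validation\n")
    print(load_yaml_config(os.path.join(d, "task.yaml")))
    # -> {'metric': 'acc', 'split': 'validation'}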
@@ -313,7 +313,7 @@ def load_yaml_config(yaml_path):
 
 env = Environment(loader=BaseLoader, undefined=StrictUndefined)

-
+
 def apply_template(template, doc):
     rtemplate = env.from_string(template)
     return rtemplate.render(**doc)
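This helper is what renders the template_aliases and doc_to_text strings from the yaml configs at the top of this commit. A small sketch with an illustrative ARC-style doc (the sample values are made up):

from jinja2 import BaseLoader, Environment, StrictUndefined

env = Environment(loader=BaseLoader, undefined=StrictUndefined)


def apply_template(template, doc):
    rtemplate = env.from_string(template)
    return rtemplate.render(**doc)


# Illustrative ARC-style document; real docs come from the HF dataset.
doc = {
    "question": "Which gas do plants absorb from the air?",
    "answerKey": "B",
    "choices": {"text": ["oxygen", "carbon dioxide"], "label": ["A", "B"]},
}
aliases = (
    "{% set answer_choices = choices['text'] %}"
    "{% set gold = choices.label.index(answerKey) %}"
)
print(apply_template(aliases + "Question: {{question}}\nAnswer:", doc))
# Question: Which gas do plants absorb from the air?
# Answer:
print(apply_template(aliases + "{{gold}}", doc))  # -> 1 (cast to int later)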
@@ -7,7 +7,8 @@ from lm_eval import evaluator, utils
 from lm_eval.tasks import ALL_TASKS
 from lm_eval.logger import eval_logger

-os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+

 class MultiChoice:
     def __init__(self, choices):
@@ -65,9 +66,10 @@ def main():
             "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
         )

-    if args.tasks != None:
+    if args.tasks is not None:
         if os.path.isdir(args.tasks):
             import glob
+
             task_names = []
             yaml_path = os.path.join(args.tasks, "*.yaml")
             for yaml_file in glob.glob(yaml_path):
@@ -80,7 +82,7 @@ def main():
                 if os.path.isfile(task):
                     config = utils.load_yaml_config(task)
                     task_names.append(config)

-    eval_logger.info(f"Selected Tasks: {task_names}")
+    eval_logger.info(f"Selected Tasks: {task_names}")

     results = evaluator.simple_evaluate(
@@ -42,6 +42,6 @@ setuptools.setup(
     extras_require={
         "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
         "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
-        "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"]
+        "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
     },
 )