Commit c4c20ff5 authored by lintangsutawika

pre-commit stuff

parent e56b950a
@@ -21,4 +21,4 @@ _CITATION = """
     year={2018},
     volume={abs/1803.05457}
 }
-"""
\ No newline at end of file
+"""
@@ -9,7 +9,7 @@ validation_split: validation
 test_split: test
 template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
 doc_to_text: "Question: {{question}}\nAnswer:"
-doc_to_target: "{{gold}}" # this will be cast to an int.
+doc_to_target: "{{gold}}" # this will be cast to an int.
 metric_list:
   - metric: acc
     aggregation: mean
@@ -19,4 +19,4 @@ metric_list:
     higher_is_better: true
   - metric: acc_mutual_info
     aggregation: mean
-    higher_is_better: true
\ No newline at end of file
+    higher_is_better: true
@@ -9,7 +9,7 @@ validation_split: validation
 test_split: test
 template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
 doc_to_text: "Question: {{question}}\nAnswer:"
-doc_to_target: "{{gold}}" # this will be cast to an int.
+doc_to_target: "{{gold}}" # this will be cast to an int.
 metric_list:
   - metric: acc
     aggregation: mean
@@ -19,4 +19,4 @@ metric_list:
     higher_is_better: true
   - metric: acc_mutual_info
     aggregation: mean
-    higher_is_better: true
\ No newline at end of file
+    higher_is_better: true
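For context on the two task configs above: the template_aliases string is a Jinja snippet that is assumed to be prepended to the other templates, so answer_choices and gold become available inside doc_to_text and doc_to_target. A minimal sketch of how the fields would render for an ARC-style doc (the sample values below are illustrative, not taken from the dataset):

from jinja2 import BaseLoader, Environment, StrictUndefined

env = Environment(loader=BaseLoader, undefined=StrictUndefined)

# Illustrative ARC-style document (question, choices, answerKey as referenced in the config above).
doc = {
    "question": "Which factor will most likely cause a person to develop a fever?",
    "choices": {
        "text": [
            "a leg muscle relaxing after exercise",
            "a bacterial population in the bloodstream",
            "several viral particles on the skin",
            "carbohydrates being digested in the stomach",
        ],
        "label": ["A", "B", "C", "D"],
    },
    "answerKey": "B",
}

aliases = (
    "{% set answer_choices = choices['text'] %}"
    "{% set gold = choices.label.index(answerKey) %}"
)

# doc_to_text: the prompt shown to the model.
print(env.from_string(aliases + "Question: {{question}}\nAnswer:").render(**doc))
# doc_to_target: the gold answer index, later cast to an int (prints "1" here).
print(env.from_string(aliases + "{{gold}}").render(**doc))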
@@ -17,9 +17,10 @@ model's sample/generation function.
 Homepage: https://github.com/openai/grade-school-math
 """
 import re

+from lm_eval import utils
+from lm_eval.api.task import Task
 from lm_eval.api.metrics import mean
 from lm_eval.api.instance import Instance
-from lm_eval.api.task import Task
 from lm_eval.prompts import get_prompt
@@ -88,7 +89,13 @@ class GradeSchoolMath8K(Task):
         """
         # NOTE: The paper implements "verifiers" that assign a score to multiple
         # solutions and output the highest ranked solution.
-        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, ["\n"]), idx=0, **kwargs)
+        return Instance(
+            request_type=self.OUTPUT_TYPE,
+            doc=doc,
+            arguments=(ctx, ["\n"]),
+            idx=0,
+            **kwargs
+        )
         # completion = rf.greedy_until(ctx, ["\n"])
         # return completion
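A note on the arguments pair above: it bundles the prompt ctx with the list of stop strings, i.e. greedy generation that halts at the first newline. A toy illustration of that stop condition (the raw completion text below is made up; the harness's own generation code is what actually applies it):

# Made-up raw model output; this only illustrates the ["\n"] stop sequence above.
raw_completion = " Natalia sold 48 + 24 = 72 clips. #### 72\nQuestion: ..."
stop_sequences = ["\n"]

completion = raw_completion
for stop in stop_sequences:
    completion = completion.split(stop)[0]

print(repr(completion))  # ' Natalia sold 48 + 24 = 72 clips. #### 72'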
@@ -60,11 +60,18 @@ class LambadaBase(Task):
         return " " + doc["text"].rsplit(" ", 1)[1]

     def construct_requests(self, doc, ctx, **kwargs):
-        return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, self.doc_to_target(doc)), **kwargs)
+        return Instance(
+            request_type=self.OUTPUT_TYPE,
+            doc=doc,
+            arguments=(ctx, self.doc_to_target(doc)),
+            **kwargs
+        )

     def process_results(self, doc, results):
         # TODO: this ^ is a hack. filters should make it so that we only have one response per request that we score
-        results = results[0]  # TODO: recheck this. currently a list of [(ll, is_greedy)] is passed in
+        results = results[
+            0
+        ]  # TODO: recheck this. currently a list of [(ll, is_greedy)] is passed in
         ll, is_greedy = results
         return {"ppl": ll, "acc": int(is_greedy)}
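For the reformatted process_results above, each entry is a (log-likelihood, is_greedy) pair for the target word, so the per-document metrics reduce to a small dictionary (the numbers below are illustrative only):

# Illustrative values: log-likelihood of the target word and whether it was
# the model's greedy continuation.
results = [(-1.5, True)]

ll, is_greedy = results[0]
print({"ppl": ll, "acc": int(is_greedy)})  # {'ppl': -1.5, 'acc': 1}

The aggregation into a corpus-level perplexity and a mean accuracy happens elsewhere in the harness.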
@@ -35,7 +35,7 @@ class PilePerplexityTask(PerplexityTask):
     def test_docs(self):
         for doc in self.dataset["train"].select(range(100)):
             yield doc
-
+
     def has_validation_docs(self):
         return False
@@ -140,4 +140,4 @@ class PileWikipedia(PilePerplexityTask):

 class PileYoutubeSubtitles(PilePerplexityTask):
-    DATASET_NAME = "pile_youtubesubtitles"
\ No newline at end of file
+    DATASET_NAME = "pile_youtubesubtitles"
@@ -37,4 +37,4 @@ metric_list:
     higher_is_better: false
   - metric: bits_per_byte
     aggregation: bits_per_byte
-    higher_is_better: false
\ No newline at end of file
+    higher_is_better: false
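As a reminder of why higher_is_better is false for both metrics in this perplexity config: bits per byte is the negated log-likelihood rescaled to base 2 and normalised by document length in bytes. A small worked example using the standard conversion (numbers are illustrative, and this is a sketch of the formula rather than the harness's exact aggregation code):

import math

# Illustrative totals over one document.
total_loglikelihood = -4156.0  # natural-log likelihood assigned by the model
num_bytes = 2048               # length of the document in bytes

bits_per_byte = -total_loglikelihood / (num_bytes * math.log(2))
print(round(bits_per_byte, 3))  # ~2.928; lower means the model compresses the text better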
 import re

-def doc_to_text(x):
+
+def doc_to_text(x):
     def _mark_span(text, span_str, span_idx, mark):
-        pattern_tmpl = r'^((?:\S+\s){N})(W)'
-        pattern = re.sub('N', str(span_idx), pattern_tmpl)
-        pattern = re.sub('W', span_str, pattern)
-        return re.sub(pattern, r'\1{0} \2 {0}'.format(mark), text)
+        pattern_tmpl = r"^((?:\S+\s){N})(W)"
+        pattern = re.sub("N", str(span_idx), pattern_tmpl)
+        pattern = re.sub("W", span_str, pattern)
+        return re.sub(pattern, r"\1{0} \2 {0}".format(mark), text)

-    text = x['text']
-    text = _mark_span(text, x['span1_text'], x['span1_index'], '*')
+    text = x["text"]
+    text = _mark_span(text, x["span1_text"], x["span1_index"], "*")
     # Compensate for 2 added "words" added in previous step.
-    span2_index = x['span2_index'] + 2 * (x['span1_index'] < x['span2_index'])
-    text = _mark_span(text, x['span2_text'], span2_index, '#')
-    return text
\ No newline at end of file
+    span2_index = x["span2_index"] + 2 * (x["span1_index"] < x["span2_index"])
+    text = _mark_span(text, x["span2_text"], span2_index, "#")
+    return text
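To make the span-marking function above concrete: span1 gets wrapped in asterisks, span2 in hash marks, and the span2 word index is shifted by the two tokens inserted around span1 whenever span1 precedes it. A usage sketch on a SuperGLUE WSC-style record (the sample values are illustrative; the field names match those used by doc_to_text above):

# Illustrative WSC-style record, assuming doc_to_text from the file above is importable.
sample = {
    "text": "The city councilmen refused the demonstrators a permit because they feared violence.",
    "span1_text": "The city councilmen",
    "span1_index": 0,
    "span2_text": "they",
    "span2_index": 9,
}

print(doc_to_text(sample))
# * The city councilmen * refused the demonstrators a permit because # they # feared violence.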
@@ -60,6 +60,7 @@ def wikitext_detokenizer(string):
     return string

+
 @register_task("wikitext")
 class WikiText(PerplexityTask):
     VERSION = "2.0"
@@ -150,7 +150,6 @@ class Reorderer:
         return res

-
 def make_table(result_dict):
     """Generate table of results."""
     from pytablewriter import MarkdownTableWriter, LatexTableWriter
@@ -262,7 +261,7 @@ def import_function(loader, node):
     function_name = loader.construct_scalar(node)
     yaml_path = os.path.dirname(loader.name)
-    module_name, function_name = function_name.split('.')
+    module_name, function_name = function_name.split(".")
     module_path = os.path.join(yaml_path, "{}.py".format(module_name))
     spec = importlib.util.spec_from_file_location(module_name, module_path)
@@ -272,29 +271,30 @@ def import_function(loader, node):
     function = getattr(module, function_name)
     return function

 # Add the import_function constructor to the YAML loader
-yaml.add_constructor('!function', import_function)
+yaml.add_constructor("!function", import_function)

+
 def load_yaml_config(yaml_path):
-    with open(yaml_path, 'rb') as file:
+    with open(yaml_path, "rb") as file:
         yaml_config = yaml.full_load(file)
         yaml_dir = os.path.dirname(yaml_path)

-        if 'include' in yaml_config:
-            include_path = yaml_config['include']
-            del yaml_config['include']
+        if "include" in yaml_config:
+            include_path = yaml_config["include"]
+            del yaml_config["include"]

             if type(include_path) == str:
                 include_path = [include_path]

             # Load from the last one first
             include_path.reverse()
             final_yaml_config = {}
             for path in include_path:
                 # Assumes that path is a full path.
                 # If not found, assume the included yaml
                 # is in the same dir as the original yaml
                 if not os.path.isfile(path):
                     path = os.path.join(yaml_dir, path)
@@ -302,9 +302,9 @@ def load_yaml_config(yaml_path):
                 try:
                     included_yaml_config = load_yaml_config(path)
                     final_yaml_config.update(included_yaml_config)
-                except:
+                except Exception as ex:
                     # If failed to load, ignore
-                    pass
+                    raise ex
             final_yaml_config.update(yaml_config)
             return final_yaml_config
@@ -313,7 +313,7 @@ def load_yaml_config(yaml_path):
 env = Environment(loader=BaseLoader, undefined=StrictUndefined)

-
+
 def apply_template(template, doc):
     rtemplate = env.from_string(template)
     return rtemplate.render(**doc)
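These hunks cover both the !function YAML constructor (which resolves a module.function path relative to the YAML file) and the include merging in load_yaml_config, whose failures now propagate instead of being silently swallowed, per the raise ex change above. A rough, self-contained sketch of how the include behaviour can be exercised (file names and field values are invented for illustration):

import os
import tempfile

from lm_eval import utils

with tempfile.TemporaryDirectory() as d:
    # Base config that a task config pulls in via `include`.
    with open(os.path.join(d, "base.yaml"), "w") as f:
        f.write("test_split: test\nmetric_list:\n  - metric: acc\n    aggregation: mean\n")

    # Task config: included keys are loaded first, then overridden by this file.
    with open(os.path.join(d, "task.yaml"), "w") as f:
        f.write('include: base.yaml\ndoc_to_text: "Question: {{question}}\\nAnswer:"\n')

    config = utils.load_yaml_config(os.path.join(d, "task.yaml"))
    print(config["test_split"], "|", config["doc_to_text"])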
@@ -7,7 +7,8 @@ from lm_eval import evaluator, utils
 from lm_eval.tasks import ALL_TASKS
 from lm_eval.logger import eval_logger

-os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+os.environ["TOKENIZERS_PARALLELISM"] = "false"

+
 class MultiChoice:
     def __init__(self, choices):
@@ -65,9 +66,10 @@ def main():
             "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
         )

-    if args.tasks != None:
+    if args.tasks is not None:
         if os.path.isdir(args.tasks):
             import glob

             task_names = []
             yaml_path = os.path.join(args.tasks, "*.yaml")
             for yaml_file in glob.glob(yaml_path):
@@ -80,7 +82,7 @@ def main():
                 if os.path.isfile(task):
                     config = utils.load_yaml_config(task)
                     task_names.append(config)
-
+
     eval_logger.info(f"Selected Tasks: {task_names}")

     results = evaluator.simple_evaluate(
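The --tasks handling above accepts either a directory of YAML task configs or individual entries. The sketch below is a rough reconstruction of that resolution logic for illustration only; the comma-splitting and the fallback to plain registered task names are assumptions not shown in these hunks:

import glob
import os

from lm_eval import utils


def resolve_tasks(tasks_arg):
    # A directory means "load every *.yaml task config inside it".
    task_names = []
    if os.path.isdir(tasks_arg):
        for yaml_file in glob.glob(os.path.join(tasks_arg, "*.yaml")):
            task_names.append(utils.load_yaml_config(yaml_file))
    else:
        # Assumption: otherwise treat the argument as a comma-separated list,
        # loading YAML paths and passing registered task names through as-is.
        for task in tasks_arg.split(","):
            if os.path.isfile(task):
                task_names.append(utils.load_yaml_config(task))
            else:
                task_names.append(task)
    return task_names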
@@ -42,6 +42,6 @@ setuptools.setup(
     extras_require={
         "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
         "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
-        "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"]
+        "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
     },
 )