Unverified Commit 57b20eef authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #857 from chrisociepa/python_3_8_compatible

Use Dict to make the code python 3.8 compatible
parents 5d4ac134 38c9689a
import ast import ast
from typing import Dict
from lm_eval import utils from lm_eval import utils
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
...@@ -7,7 +8,7 @@ from lm_eval.logger import eval_logger ...@@ -7,7 +8,7 @@ from lm_eval.logger import eval_logger
# Stores prompts in a dictionary indexed by 2 levels: # Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name. # prompt category name, and prompt name.
# This allows us to access prompts # This allows us to access prompts
PROMPT_REGISTRY: dict[str, dict[str, str]] = { PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
"qa-basic": { "qa-basic": {
"question-newline-answer": "Question: {{question}}\nAnswer:", "question-newline-answer": "Question: {{question}}\nAnswer:",
"q-newline-a": "Q: {{question}}\nA:", "q-newline-a": "Q: {{question}}\nA:",
......
import os import os
import yaml import yaml
from typing import List, Union from typing import List, Union, Dict
from lm_eval import utils from lm_eval import utils
from lm_eval import prompts from lm_eval import prompts
...@@ -15,7 +15,7 @@ from lm_eval.api.registry import ( ...@@ -15,7 +15,7 @@ from lm_eval.api.registry import (
) )
def register_configurable_task(config: dict[str, str]) -> int: def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type( SubClass = type(
config["task"] + "ConfigurableTask", config["task"] + "ConfigurableTask",
(ConfigurableTask,), (ConfigurableTask,),
...@@ -38,7 +38,7 @@ def register_configurable_task(config: dict[str, str]) -> int: ...@@ -38,7 +38,7 @@ def register_configurable_task(config: dict[str, str]) -> int:
return 0 return 0
def check_prompt_config(config: dict[str, str]) -> List[dict[str, str]]: def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
all_configs = [] all_configs = []
if "use_prompt" in config: if "use_prompt" in config:
prompt_list = prompts.load_prompt_list( prompt_list = prompts.load_prompt_list(
...@@ -69,7 +69,7 @@ def check_prompt_config(config: dict[str, str]) -> List[dict[str, str]]: ...@@ -69,7 +69,7 @@ def check_prompt_config(config: dict[str, str]) -> List[dict[str, str]]:
return all_configs return all_configs
def get_task_name_from_config(task_config: dict[str, str]) -> str: def get_task_name_from_config(task_config: Dict[str, str]) -> str:
if "dataset_name" in task_config: if "dataset_name" in task_config:
return "{dataset_path}_{dataset_name}".format(**task_config) return "{dataset_path}_{dataset_name}".format(**task_config)
else: else:
...@@ -128,7 +128,7 @@ def get_task_name_from_object(task_object): ...@@ -128,7 +128,7 @@ def get_task_name_from_object(task_object):
# TODO: pass num_fewshot and other cmdline overrides in a better way # TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs): def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
config = {**kwargs} config = {**kwargs}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment