Commit 90ad5db7 authored by lintangsutawika's avatar lintangsutawika
Browse files

merged main

parents f692caa9 b177c82c
......@@ -28,7 +28,7 @@ class OptimumLM(HFLM):
super().__init__(
device=self.openvino_device,
backend=kwargs.get("backend", "causal"),
backend=kwargs.pop("backend", "causal"),
**kwargs,
)
......
......@@ -6,6 +6,7 @@ from functools import wraps
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
......@@ -357,65 +358,164 @@ class Collator:
A class for reordering and batching elements of an array.
This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
Objects of this class have the group_by attribute which determines the method for grouping
the data while batching it. Three options include "gen_kwargs", "contexts", or None:
If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs
If group_by == "contexts" then requests will be grouped by context + cont[:-1]
If None then requests will just be reordered by length descending.
"""
def __init__(
self,
arr: List,
sort_fn: Callable,
sort_fn: Callable = lambda x: x,
group_fn: Callable = lambda x: x[1],
grouping: bool = False,
group_by: Union[Literal["gen_kwargs", "contexts"], None] = None,
) -> None:
self.grouping = grouping
self.fn = sort_fn
self.group_fn = lambda x: group_fn(x[1]) # first index are enumerated indices
self.reorder_indices: List = []
self.size = len(arr)
self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr)) # [indices, (arr)]
if self.grouping is True:
self.group_by_index()
self._group_by = group_by
# 0 indices are enumerated indices. Apply functions to original arr.
self._sort_fn = lambda x: sort_fn(x[1])
self._group_fn = lambda x: group_fn(x[1])
self._reorder_indices: List = []
self._size = len(arr)
self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple(
enumerate(arr)
) # [indices, (arr)]
if self._group_by == "contexts":
self._group_by_context()
elif self._group_by == "gen_kwargs":
self._group_by_index()
def _group_by_index(self) -> None:
"""Group the elements of a list based on their indices."""
self._arr_with_indices = self.group(
self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs"
)
def group_by_index(self) -> None:
self.arr_with_indices = self.group(
self.arr_with_indices, fn=self.group_fn, values=False
def _group_by_context(self) -> None:
"""Group the array with indices by context."""
self._arr_with_indices = self.group(
self._arr_with_indices, fn=self._group_fn, group_by="contexts"
)
def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
"""
Generates and yields batches from the reordered array.
Generates and yields batches from the reordered array. The method of grouping and batching
depends on the parameter `group_by`.
If `group_by` is set to "gen_kwargs", it will batch the
re-ordered values with same gen_kwargs for each batch.
If `group_by` is "contexts", it caches the requests by context before batching.
If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array
Parameters:
- n (int): The size of each batch. Defaults to 1.
- batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.
- batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
each batch. Optional, defaults to None.
Returns:
Iterator: An iterator over batches of reordered elements grouped as per the `group_by`
attribute.
Yields:
Iterator: An iterator over batches of reordered elements.
List of batched elements according to the `group_by` attribute.
"""
if self.grouping:
if self._group_by == "gen_kwargs":
for (
key,
values,
) in self.arr_with_indices.items(): # type: ignore
) in self._arr_with_indices.items(): # type: ignore
values = self._reorder(values)
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
elif self._group_by == "contexts":
# Get one sample from each key
values = self._reorder(
[value[0] for value in self._arr_with_indices.values()]
)
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
else:
values = self._reorder(self.arr_with_indices) # type: ignore
values = self._reorder(self._arr_with_indices) # type: ignore
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> List:
def get_cache(
self,
req_str: Tuple[str, str] = None,
cxt_toks: List[int] = None,
cont_toks: List[int] = None,
logits: torch.Tensor = None,
) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
"""
Retrieves cached single-token continuations and their associated arguments, updating indices as necessary.
The behavior of this function varies depending on how the `group_by` attribute is set:
- When `group_by` is "contexts":
The function identifies single-token continuations by checking for keys that equate to
[context+continuation][-1] and logs the indices for re-ordering.
In this mode, this function can work in two scenarios:
1. Cache Hit - Single Match:
If a single matching context-continuation pair is found in the cache,
the function yields the original arguments.
2. Cache Hit - Multiple Matches:
If multiple matching context-continuation pairs are found in the cache,
the function expands the logits batch dimension to match the number of cache hits.
It updates the original requests and continuation tokens.
- When `group_by` is not set to "contexts":
This method yields the original arguments, logits and continuation tokens,
without checking for one-token continuations.
Parameters:
- req_str (tuple[str, str]): Original strings used for CachingLM.
- cxt_toks (list[int]): Full context tokens used for lookup.
- cont_toks (list[int]): Continuation tokens for which logits were generated.
- logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys.
Yields:
- Iterator:
- req_str (tuple[str, str]): strings used for CachingLM.
- cont_toks (list[int]) : continuation tokens.
- logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times)
"""
if self._group_by == "contexts":
cache_hit: List[
Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]]
] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1]))
if (cache_size := len(cache_hit)) == 1:
self._reorder_indices.extend(x[0] for x in cache_hit)
yield req_str, cont_toks, logits
else:
# If we have matching requests then expand the batch dimension (no-op) and
# yield each along with its corresponding args.
multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size)
indices, req_str, cont_toks = zip(
*[(x[0], x[1][0], x[-1][-1]) for x in cache_hit]
)
self._reorder_indices.extend(indices)
for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits):
yield c_key, cont_tok, logit
else:
yield req_str, cont_toks, logits
def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
"""
Reorders the elements in the array based on the sorting function.
Parameters:
- arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.
- arr (list | tuple[tuple[int, Any], ...]]): The array or iterable to be reordered.
Yields:
List: Yields reordered elements one by one.
Iterator
"""
arr = sorted(arr, key=lambda x: self.fn(x[1]))
self.reorder_indices.extend([x[0] for x in arr])
arr = sorted(arr, key=self._sort_fn)
if not self._group_by == "contexts":
# If grouped by contexts then indices will be set in get_cache()
self._reorder_indices.extend([x[0] for x in arr])
yield from [x[1] for x in arr]
def get_original(self, newarr: List) -> List:
......@@ -423,15 +523,15 @@ class Collator:
Restores the original order of elements from the reordered list.
Parameters:
- newarr (List): The reordered array.
- newarr (list): The reordered array.
Returns:
List: The array with elements restored to their original order.
list: The array with elements restored to their original order.
"""
res = [None] * self.size
cov = [False] * self.size
res = [None] * self._size
cov = [False] * self._size
for ind, v in zip(self.reorder_indices, newarr):
for ind, v in zip(self._reorder_indices, newarr):
res[ind] = v
cov[ind] = True
......@@ -440,39 +540,50 @@ class Collator:
return res
def __len__(self):
return self.size
return self._size
@staticmethod
def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
def group(
arr: Iterable,
fn: Callable,
group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs",
) -> dict:
"""
Groups elements of an iterable based on a provided function.
The `group_by` parameter determines the method of grouping.
If `group_by` is "contexts", the elements are grouped by [context + cont][:-1].
If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict.
Parameters:
- arr (Iterable): The iterable to be grouped.
- fn (Callable): The function to determine the grouping.
- values (bool): If True, returns the values of the group. Defaults to False.
Returns:
Iterable: An iterable of grouped elements.
Iterator: An iterable of grouped elements.
"""
res = collections.defaultdict(list)
for ob in arr:
try:
hashable_dict = tuple(
(
key,
tuple(value)
if isinstance(value, collections.abc.Iterable)
else value,
# where ob == [context + cont]
if group_by == "contexts":
res[tuple(fn(ob))].append(ob)
else:
try:
hashable_dict = tuple(
(
key,
tuple(value)
if isinstance(value, collections.abc.Iterable)
else value,
)
for key, value in sorted(fn(ob).items())
)
for key, value in sorted(fn(ob).items())
)
res[hashable_dict].append(ob)
except TypeError:
res[fn(ob)].append(ob)
if not values:
return res
return res.values()
res[hashable_dict].append(ob)
except (TypeError, AttributeError):
res[tuple(fn(ob))].append(ob)
return res
@staticmethod
def get_chunks(_iter, n: int = 0, fn=None):
......
......@@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Tuple, Union
from tqdm import tqdm
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, divide
from lm_eval.utils import (
......@@ -35,7 +35,7 @@ def run_inference_one_model(
@register_model("vllm")
class VLLM(LM):
class VLLM(TemplateLM):
_DEFAULT_MAX_LENGTH = 2048
def __init__(
......@@ -47,6 +47,7 @@ class VLLM(LM):
tokenizer: Optional[str] = None,
tokenizer_mode: Literal["auto", "slow"] = "auto",
tokenizer_revision: Optional[str] = None,
add_bos_token: Optional[bool] = False,
tensor_parallel_size: int = 1,
quantization: Optional[str] = None,
max_gen_toks: int = 256,
......@@ -114,6 +115,7 @@ class VLLM(LM):
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
)
self.add_bos_token = add_bos_token
self._max_gen_toks = max_gen_toks
......@@ -147,10 +149,12 @@ class VLLM(LM):
self,
string: str,
left_truncate_len=None,
add_special_tokens=False,
add_special_tokens=None,
truncation=False,
):
""" """
if not add_special_tokens:
add_special_tokens = False or self.add_bos_token
encoding = self.tokenizer.encode(
string, add_special_tokens=add_special_tokens, truncation=truncation
)
......@@ -194,37 +198,6 @@ class VLLM(LM):
)
return outputs
def _encode_pair(
self, context: str, continuation: str
) -> Tuple[List[int], List[int]]:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
context_enc = self.tok_encode(context, add_special_tokens=False)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
# end of text as context
context_enc, continuation_enc = (
[self.eot_token_id],
self.tok_encode(continuation),
)
else:
context_enc, continuation_enc = self._encode_pair(context, continuation)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
loglikelihoods = []
......@@ -276,12 +249,16 @@ class VLLM(LM):
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = Collator(requests, _collate_gen, grouping=True)
re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
chunks = re_ords.get_batched(
n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
)
pbar = tqdm(total=len(requests), disable=(self.rank != 0))
pbar = tqdm(
total=len(requests),
disable=(self.rank != 0),
desc="Running generate_until requests",
)
# for each different set of kwargs, we execute all requests, by batch.
for chunk in chunks:
context_and_encoding, all_gen_kwargs = zip(*chunk)
......@@ -356,7 +333,11 @@ class VLLM(LM):
n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
)
pbar = tqdm(total=len(requests), disable=disable_tqdm)
pbar = tqdm(
total=len(requests),
disable=disable_tqdm,
desc="Running loglikelihood requests",
)
for chunk in chunks:
inputs = []
ctxlens = []
......
import os
import ast
import os
from typing import Dict
from lm_eval import utils
from lm_eval.utils import eval_logger
# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
......
import os
import abc
import collections
import logging
import os
from functools import partial
from typing import List, Union, Dict
from typing import Dict, List, Union
from lm_eval import utils
from lm_eval.api.task import Task, ConfigurableTask
import logging
from lm_eval.api.task import ConfigurableTask, Task
class TaskManager:
......@@ -16,20 +14,14 @@ class TaskManager:
and an optional directory if provided.
"""
def __init__(
self,
verbosity="INFO",
include_path=None
) -> None:
def __init__(self, verbosity="INFO", include_path=None) -> None:
self.verbosity = verbosity
self.include_path = include_path
self.logger = utils.eval_logger
self.logger.setLevel(getattr(logging, f"{verbosity}"))
self._task_index = self.initialize_tasks(
include_path=include_path
)
self._task_index = self.initialize_tasks(include_path=include_path)
self._all_tasks = sorted(list(self._task_index.keys()))
self.task_group_map = collections.defaultdict(list)
......@@ -65,27 +57,29 @@ class TaskManager:
return self._task_index
def match_tasks(self, task_list):
return utils.pattern_match(
task_list, self.all_tasks
)
return utils.pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name):
if name in self.all_tasks:
return True
return False
def _name_is_task(self, name):
def _name_is_task(self, name) -> bool:
if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]):
return True
return False
def _name_is_group(self, name):
if self._name_is_registered(name) and (self.task_index[name]["type"] == "group"):
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "group"
):
return True
return False
def _name_is_python_task(self, name):
if self._name_is_registered(name) and (self.task_index[name]["type"] == "python_task"):
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "python_task"
):
return True
return False
......@@ -117,7 +111,7 @@ class TaskManager:
return utils.load_yaml_config(yaml_path, mode="full")
def _get_tasklist(self, name):
assert self._name_is_task(name) == False
assert self._name_is_task(name) is False
return self.task_index[name]["task"]
def _process_alias(self, config, group=None):
......@@ -130,12 +124,12 @@ class TaskManager:
return config
def _load_individual_task_or_group(
self,
name_or_config: Union[str, dict] = None,
parent_name: str = None,
update_config: dict = None,
yaml_path: str = None,
) -> ConfigurableTask:
self,
name_or_config: Union[str, dict] = None,
parent_name: str = None,
update_config: dict = None,
yaml_path: str = None,
) -> ConfigurableTask:
def load_task(config, task, group=None, yaml_path=None):
if "include" in config:
assert yaml_path is not None
......@@ -174,7 +168,9 @@ class TaskManager:
group_config = self._get_config(name_or_config)
if set(group_config.keys()) > set(["task", "group"]):
update_config = {
k:v for k,v in group_config.items() if k not in ["task", "group"]
k: v
for k, v in group_config.items()
if k not in ["task", "group"]
}
yaml_path = self._get_yaml_path(group_name)
......@@ -183,9 +179,8 @@ class TaskManager:
update_config.pop("group_alias")
if isinstance(name_or_config, dict):
if update_config is not None:
name_or_config={
name_or_config = {
**name_or_config,
**update_config,
}
......@@ -196,7 +191,9 @@ class TaskManager:
# if self._name_is_task(name) is False:
if self._name_is_group(name):
group_name = name
update_config = {k:v for k,v in name_or_config.items() if k != "task"}
update_config = {
k: v for k, v in name_or_config.items() if k != "task"
}
subtask_list = self._get_tasklist(name)
if subtask_list == -1:
subtask_list = self._get_config(name)["task"]
......@@ -207,36 +204,53 @@ class TaskManager:
# Check if this is a duplicate.
if parent_name is not None:
name_or_config["group"] = parent_name
num_duplicate = len(list(filter(lambda x: x.startswith(name), self.task_group_map[parent_name])))
num_duplicate = len(
list(
filter(
lambda x: x.startswith(name),
self.task_group_map[parent_name],
)
)
)
if num_duplicate > 0:
name = f"{name}-{num_duplicate}"
self.task_group_map[parent_name].append(name)
task_config={
**base_task_config,
**name_or_config,
}
task_config = {
**base_task_config,
**name_or_config,
}
else:
task_config = name_or_config
return load_task(task_config, task=name, group=parent_name, yaml_path=yaml_path)
return load_task(
task_config, task=name, group=parent_name, yaml_path=yaml_path
)
else:
group_name = name_or_config["group"]
subtask_list = name_or_config["task"]
# update_config = {k:v for k,v in name_or_config.items() if k != "task"}
if set(name_or_config.keys()) > set(["task", "group"]):
update_config = {
k:v for k,v in name_or_config.items() if k not in ["task", "group"]
k: v
for k, v in name_or_config.items()
if k not in ["task", "group"]
}
all_subtasks = {}
if (parent_name is not None):
if parent_name is not None:
all_subtasks = {group_name: (parent_name, None)}
fn = partial(self._load_individual_task_or_group, parent_name=group_name, update_config=update_config, yaml_path=yaml_path)
all_subtasks = {**all_subtasks, **dict(collections.ChainMap(*map(fn, subtask_list)))}
fn = partial(
self._load_individual_task_or_group,
parent_name=group_name,
update_config=update_config,
yaml_path=yaml_path,
)
all_subtasks = {
**all_subtasks,
**dict(collections.ChainMap(*map(fn, subtask_list))),
}
return all_subtasks
def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
"""Loads a dictionary of task objects from a list
......@@ -250,12 +264,7 @@ class TaskManager:
task_list = [task_list]
all_loaded_tasks = dict(
collections.ChainMap(
*map(
self._load_individual_task_or_group,
task_list
)
)
collections.ChainMap(*map(self._load_individual_task_or_group, task_list))
)
return all_loaded_tasks
......@@ -299,11 +308,11 @@ class TaskManager:
# This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path": yaml_path,
}
......@@ -322,7 +331,7 @@ class TaskManager:
tasks_and_groups[task] = {
"type": "task",
"yaml_path": yaml_path,
}
}
if "group" in config:
groups = config["group"]
......@@ -343,6 +352,7 @@ class TaskManager:
return tasks_and_groups
def include_path(task_dir):
logger = utils.eval_logger
logger.setLevel(getattr(logging, "INFO"))
......@@ -352,6 +362,7 @@ def include_path(task_dir):
)
return 0
def initialize_tasks(verbosity="INFO"):
logger = utils.eval_logger
logger.setLevel(getattr(logging, f"{verbosity}"))
......@@ -362,6 +373,7 @@ def initialize_tasks(verbosity="INFO"):
)
return 0
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
if "task" in task_config:
return task_config["task"]
......@@ -370,6 +382,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
else:
return "{dataset_path}".format(**task_config)
def get_task_name_from_object(task_object):
if hasattr(task_object, "config"):
return task_object._config["task"]
......@@ -382,7 +395,10 @@ def get_task_name_from_object(task_object):
else type(task_object).__name__
)
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None):
def get_task_dict(
task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None
):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]]
......@@ -409,7 +425,9 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: Ta
if task_manager is None:
task_manager = TaskManager()
task_name_from_string_dict = task_manager.load_task_or_group(string_task_name_list)
task_name_from_string_dict = task_manager.load_task_or_group(
string_task_name_list
)
for task_element in others_task_name_list:
if isinstance(task_element, dict):
......@@ -427,6 +445,7 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: Ta
assert set(task_name_from_string_dict.keys()).isdisjoint(
set(task_name_from_object_dict.keys())
)
return {
**task_name_from_string_dict,
**task_name_from_config_dict,
......
# ArabicMMLU
### Paper
ArabicMMLU: Measuring massive multitask language understanding in Arabic
This dataset has been translated from the original MMLU with the help of GPT-4.
The original dataset: [MMLU](https://arxiv.org/pdf/2009.03300v3.pdf)
The translation was carried out in collaboration with the [AceGPT](https://arxiv.org/abs/2309.12053) researchers.
ArabicMMLU is a comprehensive evaluation benchmark specifically designed to evaluate the knowledge and reasoning abilities of LLMs within the context of Arabic language and culture.
ArabicMMLU covers a wide range of subjects, comprising 57 topics that span from elementary to advanced professional levels.
Homepage: [AceGPT Homepage](https://github.com/FreedomIntelligence/AceGPT/tree/main/eval/benchmark_eval/benchmarks/MMLUArabic)
### Citation
### Groups and Tasks
#### Groups
- `ammlu`: All 57 subjects of the ArabicMMLU dataset, evaluated following the methodology in MMLU's original implementation.
#### Tasks
The following tasks evaluate subjects in the ArabicMMLU dataset using loglikelihood-based multiple-choice scoring:
- `ammlu_{subject_english}`
### Checklist
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation?
* [x] Yes, original implementation contributed by author of the benchmark
If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
# Shared template for all ammlu_* subject tasks; per-subject YAMLs pull this
# in via an `include` key and override only `dataset_name`, `task`, and
# `description`.
group: ammlu
dataset_path: Hennara/ammlu
test_split: test
fewshot_split: dev
fewshot_config:
  sampler: first_n
output_type: multiple_choice
# Prompt: the question followed by the four options; "الجواب:" is Arabic for "Answer:".
doc_to_text: "{{Question.strip()}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nالجواب:"
doc_to_choice: ["A", "B", "C", "D"]
# Target is the index of the gold letter within the choice list.
doc_to_target: "{{['A', 'B', 'C', 'D'].index(Answer)}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
  - metric: acc_norm
    aggregation: mean
    higher_is_better: true
metadata:
  version: 0.0
"""
Take in a YAML, and output all other splits with this YAML
"""
import argparse
import os
import yaml
from tqdm import tqdm
# Mapping from each English MMLU subject name to its Arabic category label.
# The category string is interpolated into the generated task's `description`
# (the few-shot instruction) by the __main__ block below.
# NOTE(review): the labels appear to denote broad discipline groupings
# (STEM / social sciences / humanities / other) — confirm against the dataset.
SUBJECTS = {
    "abstract_algebra": "ألعلوم وتقنية المعلومات و الرياضيات",
    "anatomy": "ألعلوم وتقنية المعلومات و الرياضيات",
    "astronomy": "ألعلوم وتقنية المعلومات و الرياضيات",
    "business_ethics": "علوم أخرى",
    "clinical_knowledge": "علوم أخرى",
    "college_biology": "ألعلوم وتقنية المعلومات و الرياضيات",
    "college_chemistry": "ألعلوم وتقنية المعلومات و الرياضيات",
    "college_computer_science": "ألعلوم وتقنية المعلومات و الرياضيات",
    "college_mathematics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "college_medicine": "علوم أخرى",
    "college_physics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "computer_security": "ألعلوم وتقنية المعلومات و الرياضيات",
    "conceptual_physics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "econometrics": "العلوم الإجتماعية",
    "electrical_engineering": "ألعلوم وتقنية المعلومات و الرياضيات",
    "elementary_mathematics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "formal_logic": "العلوم الانسانية",
    "global_facts": "علوم أخرى",
    "high_school_biology": "ألعلوم وتقنية المعلومات و الرياضيات",
    "high_school_chemistry": "ألعلوم وتقنية المعلومات و الرياضيات",
    "high_school_computer_science": "ألعلوم وتقنية المعلومات و الرياضيات",
    "high_school_european_history": "العلوم الانسانية",
    "high_school_geography": "العلوم الإجتماعية",
    "high_school_government_and_politics": "العلوم الإجتماعية",
    "high_school_macroeconomics": "العلوم الإجتماعية",
    "high_school_mathematics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "high_school_microeconomics": "العلوم الإجتماعية",
    "high_school_physics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "high_school_psychology": "العلوم الإجتماعية",
    "high_school_statistics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "high_school_us_history": "العلوم الانسانية",
    "high_school_world_history": "العلوم الانسانية",
    "human_aging": "علوم أخرى",
    "human_sexuality": "العلوم الإجتماعية",
    "international_law": "العلوم الانسانية",
    "jurisprudence": "العلوم الانسانية",
    "logical_fallacies": "العلوم الانسانية",
    "machine_learning": "ألعلوم وتقنية المعلومات و الرياضيات",
    "management": "علوم أخرى",
    "marketing": "علوم أخرى",
    "medical_genetics": "علوم أخرى",
    "miscellaneous": "علوم أخرى",
    "moral_disputes": "العلوم الانسانية",
    "moral_scenarios": "العلوم الانسانية",
    "nutrition": "علوم أخرى",
    "philosophy": "العلوم الانسانية",
    "prehistory": "العلوم الانسانية",
    "professional_accounting": "علوم أخرى",
    "professional_law": "العلوم الانسانية",
    "professional_medicine": "علوم أخرى",
    "professional_psychology": "العلوم الإجتماعية",
    "public_relations": "العلوم الإجتماعية",
    "security_studies": "العلوم الإجتماعية",
    "sociology": "العلوم الإجتماعية",
    "us_foreign_policy": "العلوم الإجتماعية",
    "virology": "علوم أخرى",
    "world_religions": "العلوم الانسانية",
}
def parse_args():
    """Parse command-line options for the config-generation script.

    Returns:
        argparse.Namespace with attributes ``base_yaml_path`` (required),
        ``save_prefix_path`` (default ``"ammlu"``), ``cot_prompt_path``
        (default ``None``), and ``task_prefix`` (default ``""``).
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--base_yaml_path", required=True)
    arg_parser.add_argument("--save_prefix_path", default="ammlu")
    arg_parser.add_argument("--cot_prompt_path", default=None)
    arg_parser.add_argument("--task_prefix", default="")
    return arg_parser.parse_args()
if __name__ == "__main__":
    args = parse_args()

    # Keep only the filename of the base YAML so generated configs can
    # reference it via an `include` key.
    base_yaml_name = os.path.split(args.base_yaml_path)[-1]
    with open(args.base_yaml_path, encoding="utf-8") as f:
        base_yaml = yaml.full_load(f)

    # Optional chain-of-thought prompts, keyed by subject name.
    if args.cot_prompt_path is not None:
        import json

        with open(args.cot_prompt_path, encoding="utf-8") as f:
            cot_file = json.load(f)

    for subject, category_ar in tqdm(SUBJECTS.items()):
        # Prefer a CoT prompt when one was supplied; otherwise fall back to
        # the generic Arabic instruction built from the subject's category.
        if args.cot_prompt_path is not None:
            description = cot_file[subject]
        else:
            description = f"فم بعملية التقييم في مجال {category_ar} \n\n"

        if args.task_prefix != "":
            task_name = f"ammlu_{args.task_prefix}_{subject}"
        else:
            task_name = f"ammlu_{subject}"

        yaml_dict = {
            "include": base_yaml_name,
            "task": task_name,
            "dataset_name": subject,
            "description": description,
        }

        file_save_path = args.save_prefix_path + f"_{subject}.yaml"
        print(f"Saving yaml for subset {subject} to {file_save_path}")
        with open(file_save_path, "w", encoding="utf-8") as yaml_file:
            yaml.dump(
                yaml_dict,
                yaml_file,
                width=float("inf"),
                allow_unicode=True,
                default_style='"',
            )
"dataset_name": "abstract_algebra"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_abstract_algebra"
"dataset_name": "anatomy"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_anatomy"
"dataset_name": "astronomy"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_astronomy"
"dataset_name": "business_ethics"
"description": "فم بعملية التقييم في مجال علوم أخرى \n\n"
"include": "_default_template_yaml"
"task": "ammlu_business_ethics"
"dataset_name": "clinical_knowledge"
"description": "فم بعملية التقييم في مجال علوم أخرى \n\n"
"include": "_default_template_yaml"
"task": "ammlu_clinical_knowledge"
"dataset_name": "college_biology"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_biology"
"dataset_name": "college_chemistry"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_chemistry"
"dataset_name": "college_computer_science"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_computer_science"
"dataset_name": "college_mathematics"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_mathematics"
"dataset_name": "college_medicine"
"description": "فم بعملية التقييم في مجال علوم أخرى \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_medicine"
"dataset_name": "college_physics"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_physics"
"dataset_name": "computer_security"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_computer_security"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment