Commit 90ad5db7 authored by lintangsutawika's avatar lintangsutawika
Browse files

merged main

parents f692caa9 b177c82c
...@@ -28,7 +28,7 @@ class OptimumLM(HFLM): ...@@ -28,7 +28,7 @@ class OptimumLM(HFLM):
super().__init__( super().__init__(
device=self.openvino_device, device=self.openvino_device,
backend=kwargs.get("backend", "causal"), backend=kwargs.pop("backend", "causal"),
**kwargs, **kwargs,
) )
......
...@@ -6,6 +6,7 @@ from functools import wraps ...@@ -6,6 +6,7 @@ from functools import wraps
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Dict,
Iterable, Iterable,
Iterator, Iterator,
List, List,
...@@ -357,65 +358,164 @@ class Collator: ...@@ -357,65 +358,164 @@ class Collator:
A class for reordering and batching elements of an array. A class for reordering and batching elements of an array.
This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data. This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
Objects of this class have the group_by attribute which determines the method for grouping
the data while batching it. Three options include "gen_kwargs", "contexts", or None:
If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs
If group_by == "contexts" then requests will be grouped by context + cont[:-1]
If None then requests will just be reordered by length descending.
""" """
def __init__( def __init__(
self, self,
arr: List, arr: List,
sort_fn: Callable, sort_fn: Callable = lambda x: x,
group_fn: Callable = lambda x: x[1], group_fn: Callable = lambda x: x[1],
grouping: bool = False, group_by: Union[Literal["gen_kwargs", "contexts"], None] = None,
) -> None: ) -> None:
self.grouping = grouping self._group_by = group_by
self.fn = sort_fn # 0 indices are enumerated indices. Apply functions to original arr.
self.group_fn = lambda x: group_fn(x[1]) # first index are enumerated indices self._sort_fn = lambda x: sort_fn(x[1])
self.reorder_indices: List = [] self._group_fn = lambda x: group_fn(x[1])
self.size = len(arr) self._reorder_indices: List = []
self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr)) # [indices, (arr)] self._size = len(arr)
if self.grouping is True: self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple(
self.group_by_index() enumerate(arr)
) # [indices, (arr)]
if self._group_by == "contexts":
self._group_by_context()
elif self._group_by == "gen_kwargs":
self._group_by_index()
def _group_by_index(self) -> None:
"""Group the elements of a list based on their indices."""
self._arr_with_indices = self.group(
self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs"
)
def group_by_index(self) -> None: def _group_by_context(self) -> None:
self.arr_with_indices = self.group( """Group the array with indices by context."""
self.arr_with_indices, fn=self.group_fn, values=False self._arr_with_indices = self.group(
self._arr_with_indices, fn=self._group_fn, group_by="contexts"
) )
def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator: def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
""" """
Generates and yields batches from the reordered array. Generates and yields batches from the reordered array. The method of grouping and batching
depends on the parameter `group_by`.
If `group_by` is set to "gen_kwargs", it will batch the
re-ordered values with same gen_kwargs for each batch.
If `group_by` is "contexts", it caches the requests by context before batching.
If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array
Parameters: Parameters:
- n (int): The size of each batch. Defaults to 1. - n (int): The size of each batch. Defaults to 1.
- batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None. - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
each batch. Optional, defaults to None.
Returns:
Iterator: An iterator over batches of reordered elements grouped as per the `group_by`
attribute.
Yields: Yields:
Iterator: An iterator over batches of reordered elements. List of batched elements according to the `group_by` attribute.
""" """
if self.grouping: if self._group_by == "gen_kwargs":
for ( for (
key, key,
values, values,
) in self.arr_with_indices.items(): # type: ignore ) in self._arr_with_indices.items(): # type: ignore
values = self._reorder(values) values = self._reorder(values)
batch = self.get_chunks(values, n=n, fn=batch_fn) batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch yield from batch
elif self._group_by == "contexts":
# Get one sample from each key
values = self._reorder(
[value[0] for value in self._arr_with_indices.values()]
)
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
else: else:
values = self._reorder(self.arr_with_indices) # type: ignore values = self._reorder(self._arr_with_indices) # type: ignore
batch = self.get_chunks(values, n=n, fn=batch_fn) batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch yield from batch
def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> List: def get_cache(
self,
req_str: Tuple[str, str] = None,
cxt_toks: List[int] = None,
cont_toks: List[int] = None,
logits: torch.Tensor = None,
) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
"""
Retrieves cached single-token continuations and their associated arguments, updating indices as necessary.
The behavior of this function varies depending on how the `group_by` attribute is set:
- When `group_by` is "contexts":
The function identifies single-token continuations by checking for keys that equate to
[context+continuation][-1] and logs the indices for re-ordering.
In this mode, this function can work in two scenarios:
1. Cache Hit - Single Match:
If a single matching context-continuation pair is found in the cache,
the function yields the original arguments.
2. Cache Hit - Multiple Matches:
If multiple matching context-continuation pairs are found in the cache,
the function expands the logits batch dimension to match the number of cache hits.
It updates the original requests and continuation tokens.
- When `group_by` is not set to "contexts":
This method yields the original arguments, logits and continuation tokens,
without checking for one-token continuations.
Parameters:
- req_str (tuple[str, str]): Original strings used for CachingLM.
- cxt_toks (list[int]): Full context tokens used for lookup.
- cont_toks (list[int]): Continuation tokens for which logits were generated.
- logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys.
Yields:
- Iterator:
- req_str (tuple[str, str]): strings used for CachingLM.
- cont_toks (list[int]) : continuation tokens.
- logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times)
"""
if self._group_by == "contexts":
cache_hit: List[
Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]]
] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1]))
if (cache_size := len(cache_hit)) == 1:
self._reorder_indices.extend(x[0] for x in cache_hit)
yield req_str, cont_toks, logits
else:
# If we have matching requests then expand the batch dimension (no-op) and
# yield each along with its corresponding args.
multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size)
indices, req_str, cont_toks = zip(
*[(x[0], x[1][0], x[-1][-1]) for x in cache_hit]
)
self._reorder_indices.extend(indices)
for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits):
yield c_key, cont_tok, logit
else:
yield req_str, cont_toks, logits
def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
""" """
Reorders the elements in the array based on the sorting function. Reorders the elements in the array based on the sorting function.
Parameters: Parameters:
- arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered. - arr (list | tuple[tuple[int, Any], ...]]): The array or iterable to be reordered.
Yields: Yields:
List: Yields reordered elements one by one. Iterator
""" """
arr = sorted(arr, key=lambda x: self.fn(x[1])) arr = sorted(arr, key=self._sort_fn)
self.reorder_indices.extend([x[0] for x in arr]) if not self._group_by == "contexts":
# If grouped by contexts then indices will be set in get_cache()
self._reorder_indices.extend([x[0] for x in arr])
yield from [x[1] for x in arr] yield from [x[1] for x in arr]
def get_original(self, newarr: List) -> List: def get_original(self, newarr: List) -> List:
...@@ -423,15 +523,15 @@ class Collator: ...@@ -423,15 +523,15 @@ class Collator:
Restores the original order of elements from the reordered list. Restores the original order of elements from the reordered list.
Parameters: Parameters:
- newarr (List): The reordered array. - newarr (list): The reordered array.
Returns: Returns:
List: The array with elements restored to their original order. list: The array with elements restored to their original order.
""" """
res = [None] * self.size res = [None] * self._size
cov = [False] * self.size cov = [False] * self._size
for ind, v in zip(self.reorder_indices, newarr): for ind, v in zip(self._reorder_indices, newarr):
res[ind] = v res[ind] = v
cov[ind] = True cov[ind] = True
...@@ -440,39 +540,50 @@ class Collator: ...@@ -440,39 +540,50 @@ class Collator:
return res return res
def __len__(self): def __len__(self):
return self.size return self._size
@staticmethod @staticmethod
def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable: def group(
arr: Iterable,
fn: Callable,
group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs",
) -> dict:
""" """
Groups elements of an iterable based on a provided function. Groups elements of an iterable based on a provided function.
The `group_by` parameter determines the method of grouping.
If `group_by` is "contexts", the elements are grouped by [context + cont][:-1].
If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict.
Parameters: Parameters:
- arr (Iterable): The iterable to be grouped. - arr (Iterable): The iterable to be grouped.
- fn (Callable): The function to determine the grouping. - fn (Callable): The function to determine the grouping.
- values (bool): If True, returns the values of the group. Defaults to False. - values (bool): If True, returns the values of the group. Defaults to False.
Returns: Returns:
Iterable: An iterable of grouped elements. Iterator: An iterable of grouped elements.
""" """
res = collections.defaultdict(list) res = collections.defaultdict(list)
for ob in arr: for ob in arr:
try: # where ob == [context + cont]
hashable_dict = tuple( if group_by == "contexts":
( res[tuple(fn(ob))].append(ob)
key, else:
tuple(value) try:
if isinstance(value, collections.abc.Iterable) hashable_dict = tuple(
else value, (
key,
tuple(value)
if isinstance(value, collections.abc.Iterable)
else value,
)
for key, value in sorted(fn(ob).items())
) )
for key, value in sorted(fn(ob).items()) res[hashable_dict].append(ob)
) except (TypeError, AttributeError):
res[hashable_dict].append(ob) res[tuple(fn(ob))].append(ob)
except TypeError: return res
res[fn(ob)].append(ob)
if not values:
return res
return res.values()
@staticmethod @staticmethod
def get_chunks(_iter, n: int = 0, fn=None): def get_chunks(_iter, n: int = 0, fn=None):
......
...@@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Tuple, Union ...@@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Tuple, Union
from tqdm import tqdm from tqdm import tqdm
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.model import LM from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, divide from lm_eval.models.utils import Collator, divide
from lm_eval.utils import ( from lm_eval.utils import (
...@@ -35,7 +35,7 @@ def run_inference_one_model( ...@@ -35,7 +35,7 @@ def run_inference_one_model(
@register_model("vllm") @register_model("vllm")
class VLLM(LM): class VLLM(TemplateLM):
_DEFAULT_MAX_LENGTH = 2048 _DEFAULT_MAX_LENGTH = 2048
def __init__( def __init__(
...@@ -47,6 +47,7 @@ class VLLM(LM): ...@@ -47,6 +47,7 @@ class VLLM(LM):
tokenizer: Optional[str] = None, tokenizer: Optional[str] = None,
tokenizer_mode: Literal["auto", "slow"] = "auto", tokenizer_mode: Literal["auto", "slow"] = "auto",
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
add_bos_token: Optional[bool] = False,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
quantization: Optional[str] = None, quantization: Optional[str] = None,
max_gen_toks: int = 256, max_gen_toks: int = 256,
...@@ -114,6 +115,7 @@ class VLLM(LM): ...@@ -114,6 +115,7 @@ class VLLM(LM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision, tokenizer_revision=tokenizer_revision,
) )
self.add_bos_token = add_bos_token
self._max_gen_toks = max_gen_toks self._max_gen_toks = max_gen_toks
...@@ -147,10 +149,12 @@ class VLLM(LM): ...@@ -147,10 +149,12 @@ class VLLM(LM):
self, self,
string: str, string: str,
left_truncate_len=None, left_truncate_len=None,
add_special_tokens=False, add_special_tokens=None,
truncation=False, truncation=False,
): ):
""" """ """ """
if not add_special_tokens:
add_special_tokens = False or self.add_bos_token
encoding = self.tokenizer.encode( encoding = self.tokenizer.encode(
string, add_special_tokens=add_special_tokens, truncation=truncation string, add_special_tokens=add_special_tokens, truncation=truncation
) )
...@@ -194,37 +198,6 @@ class VLLM(LM): ...@@ -194,37 +198,6 @@ class VLLM(LM):
) )
return outputs return outputs
def _encode_pair(
self, context: str, continuation: str
) -> Tuple[List[int], List[int]]:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
context_enc = self.tok_encode(context, add_special_tokens=False)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
# end of text as context
context_enc, continuation_enc = (
[self.eot_token_id],
self.tok_encode(continuation),
)
else:
context_enc, continuation_enc = self._encode_pair(context, continuation)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
loglikelihoods = [] loglikelihoods = []
...@@ -276,12 +249,16 @@ class VLLM(LM): ...@@ -276,12 +249,16 @@ class VLLM(LM):
# we group requests by their generation_kwargs, # we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch. # in the same batch.
re_ords = Collator(requests, _collate_gen, grouping=True) re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
chunks = re_ords.get_batched( chunks = re_ords.get_batched(
n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
) )
pbar = tqdm(total=len(requests), disable=(self.rank != 0)) pbar = tqdm(
total=len(requests),
disable=(self.rank != 0),
desc="Running generate_until requests",
)
# for each different set of kwargs, we execute all requests, by batch. # for each different set of kwargs, we execute all requests, by batch.
for chunk in chunks: for chunk in chunks:
context_and_encoding, all_gen_kwargs = zip(*chunk) context_and_encoding, all_gen_kwargs = zip(*chunk)
...@@ -356,7 +333,11 @@ class VLLM(LM): ...@@ -356,7 +333,11 @@ class VLLM(LM):
n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
) )
pbar = tqdm(total=len(requests), disable=disable_tqdm) pbar = tqdm(
total=len(requests),
disable=disable_tqdm,
desc="Running loglikelihood requests",
)
for chunk in chunks: for chunk in chunks:
inputs = [] inputs = []
ctxlens = [] ctxlens = []
......
import os
import ast import ast
import os
from typing import Dict from typing import Dict
from lm_eval import utils from lm_eval import utils
from lm_eval.utils import eval_logger from lm_eval.utils import eval_logger
# Prompt library. # Prompt library.
# Stores prompts in a dictionary indexed by 2 levels: # Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name. # prompt category name, and prompt name.
......
import os
import abc import abc
import collections import collections
import logging
import os
from functools import partial from functools import partial
from typing import List, Union, Dict from typing import Dict, List, Union
from lm_eval import utils from lm_eval import utils
from lm_eval.api.task import Task, ConfigurableTask from lm_eval.api.task import ConfigurableTask, Task
import logging
class TaskManager: class TaskManager:
...@@ -16,20 +14,14 @@ class TaskManager: ...@@ -16,20 +14,14 @@ class TaskManager:
and an optional directory if provided. and an optional directory if provided.
""" """
def __init__(
self,
verbosity="INFO",
include_path=None
) -> None:
def __init__(self, verbosity="INFO", include_path=None) -> None:
self.verbosity = verbosity self.verbosity = verbosity
self.include_path = include_path self.include_path = include_path
self.logger = utils.eval_logger self.logger = utils.eval_logger
self.logger.setLevel(getattr(logging, f"{verbosity}")) self.logger.setLevel(getattr(logging, f"{verbosity}"))
self._task_index = self.initialize_tasks( self._task_index = self.initialize_tasks(include_path=include_path)
include_path=include_path
)
self._all_tasks = sorted(list(self._task_index.keys())) self._all_tasks = sorted(list(self._task_index.keys()))
self.task_group_map = collections.defaultdict(list) self.task_group_map = collections.defaultdict(list)
...@@ -65,27 +57,29 @@ class TaskManager: ...@@ -65,27 +57,29 @@ class TaskManager:
return self._task_index return self._task_index
def match_tasks(self, task_list): def match_tasks(self, task_list):
return utils.pattern_match( return utils.pattern_match(task_list, self.all_tasks)
task_list, self.all_tasks
)
def _name_is_registered(self, name): def _name_is_registered(self, name):
if name in self.all_tasks: if name in self.all_tasks:
return True return True
return False return False
def _name_is_task(self, name): def _name_is_task(self, name) -> bool:
if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]): if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]):
return True return True
return False return False
def _name_is_group(self, name): def _name_is_group(self, name):
if self._name_is_registered(name) and (self.task_index[name]["type"] == "group"): if self._name_is_registered(name) and (
self.task_index[name]["type"] == "group"
):
return True return True
return False return False
def _name_is_python_task(self, name): def _name_is_python_task(self, name):
if self._name_is_registered(name) and (self.task_index[name]["type"] == "python_task"): if self._name_is_registered(name) and (
self.task_index[name]["type"] == "python_task"
):
return True return True
return False return False
...@@ -117,7 +111,7 @@ class TaskManager: ...@@ -117,7 +111,7 @@ class TaskManager:
return utils.load_yaml_config(yaml_path, mode="full") return utils.load_yaml_config(yaml_path, mode="full")
def _get_tasklist(self, name): def _get_tasklist(self, name):
assert self._name_is_task(name) == False assert self._name_is_task(name) is False
return self.task_index[name]["task"] return self.task_index[name]["task"]
def _process_alias(self, config, group=None): def _process_alias(self, config, group=None):
...@@ -130,12 +124,12 @@ class TaskManager: ...@@ -130,12 +124,12 @@ class TaskManager:
return config return config
def _load_individual_task_or_group( def _load_individual_task_or_group(
self, self,
name_or_config: Union[str, dict] = None, name_or_config: Union[str, dict] = None,
parent_name: str = None, parent_name: str = None,
update_config: dict = None, update_config: dict = None,
yaml_path: str = None, yaml_path: str = None,
) -> ConfigurableTask: ) -> ConfigurableTask:
def load_task(config, task, group=None, yaml_path=None): def load_task(config, task, group=None, yaml_path=None):
if "include" in config: if "include" in config:
assert yaml_path is not None assert yaml_path is not None
...@@ -174,7 +168,9 @@ class TaskManager: ...@@ -174,7 +168,9 @@ class TaskManager:
group_config = self._get_config(name_or_config) group_config = self._get_config(name_or_config)
if set(group_config.keys()) > set(["task", "group"]): if set(group_config.keys()) > set(["task", "group"]):
update_config = { update_config = {
k:v for k,v in group_config.items() if k not in ["task", "group"] k: v
for k, v in group_config.items()
if k not in ["task", "group"]
} }
yaml_path = self._get_yaml_path(group_name) yaml_path = self._get_yaml_path(group_name)
...@@ -183,9 +179,8 @@ class TaskManager: ...@@ -183,9 +179,8 @@ class TaskManager:
update_config.pop("group_alias") update_config.pop("group_alias")
if isinstance(name_or_config, dict): if isinstance(name_or_config, dict):
if update_config is not None: if update_config is not None:
name_or_config={ name_or_config = {
**name_or_config, **name_or_config,
**update_config, **update_config,
} }
...@@ -196,7 +191,9 @@ class TaskManager: ...@@ -196,7 +191,9 @@ class TaskManager:
# if self._name_is_task(name) is False: # if self._name_is_task(name) is False:
if self._name_is_group(name): if self._name_is_group(name):
group_name = name group_name = name
update_config = {k:v for k,v in name_or_config.items() if k != "task"} update_config = {
k: v for k, v in name_or_config.items() if k != "task"
}
subtask_list = self._get_tasklist(name) subtask_list = self._get_tasklist(name)
if subtask_list == -1: if subtask_list == -1:
subtask_list = self._get_config(name)["task"] subtask_list = self._get_config(name)["task"]
...@@ -207,36 +204,53 @@ class TaskManager: ...@@ -207,36 +204,53 @@ class TaskManager:
# Check if this is a duplicate. # Check if this is a duplicate.
if parent_name is not None: if parent_name is not None:
name_or_config["group"] = parent_name name_or_config["group"] = parent_name
num_duplicate = len(list(filter(lambda x: x.startswith(name), self.task_group_map[parent_name]))) num_duplicate = len(
list(
filter(
lambda x: x.startswith(name),
self.task_group_map[parent_name],
)
)
)
if num_duplicate > 0: if num_duplicate > 0:
name = f"{name}-{num_duplicate}" name = f"{name}-{num_duplicate}"
self.task_group_map[parent_name].append(name) self.task_group_map[parent_name].append(name)
task_config={ task_config = {
**base_task_config, **base_task_config,
**name_or_config, **name_or_config,
} }
else: else:
task_config = name_or_config task_config = name_or_config
return load_task(task_config, task=name, group=parent_name, yaml_path=yaml_path) return load_task(
task_config, task=name, group=parent_name, yaml_path=yaml_path
)
else: else:
group_name = name_or_config["group"] group_name = name_or_config["group"]
subtask_list = name_or_config["task"] subtask_list = name_or_config["task"]
# update_config = {k:v for k,v in name_or_config.items() if k != "task"}
if set(name_or_config.keys()) > set(["task", "group"]): if set(name_or_config.keys()) > set(["task", "group"]):
update_config = { update_config = {
k:v for k,v in name_or_config.items() if k not in ["task", "group"] k: v
for k, v in name_or_config.items()
if k not in ["task", "group"]
} }
all_subtasks = {} all_subtasks = {}
if (parent_name is not None): if parent_name is not None:
all_subtasks = {group_name: (parent_name, None)} all_subtasks = {group_name: (parent_name, None)}
fn = partial(self._load_individual_task_or_group, parent_name=group_name, update_config=update_config, yaml_path=yaml_path) fn = partial(
all_subtasks = {**all_subtasks, **dict(collections.ChainMap(*map(fn, subtask_list)))} self._load_individual_task_or_group,
parent_name=group_name,
update_config=update_config,
yaml_path=yaml_path,
)
all_subtasks = {
**all_subtasks,
**dict(collections.ChainMap(*map(fn, subtask_list))),
}
return all_subtasks return all_subtasks
def load_task_or_group(self, task_list: Union[str, list] = None) -> dict: def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
"""Loads a dictionary of task objects from a list """Loads a dictionary of task objects from a list
...@@ -250,12 +264,7 @@ class TaskManager: ...@@ -250,12 +264,7 @@ class TaskManager:
task_list = [task_list] task_list = [task_list]
all_loaded_tasks = dict( all_loaded_tasks = dict(
collections.ChainMap( collections.ChainMap(*map(self._load_individual_task_or_group, task_list))
*map(
self._load_individual_task_or_group,
task_list
)
)
) )
return all_loaded_tasks return all_loaded_tasks
...@@ -299,11 +308,11 @@ class TaskManager: ...@@ -299,11 +308,11 @@ class TaskManager:
# This is a group config # This is a group config
tasks_and_groups[config["group"]] = { tasks_and_groups[config["group"]] = {
"type": "group", "type": "group",
"task": -1, # This signals that "task": -1, # This signals that
# we don't need to know # we don't need to know
# the task list for indexing # the task list for indexing
# as it can be loaded # as it can be loaded
# when called. # when called.
"yaml_path": yaml_path, "yaml_path": yaml_path,
} }
...@@ -322,7 +331,7 @@ class TaskManager: ...@@ -322,7 +331,7 @@ class TaskManager:
tasks_and_groups[task] = { tasks_and_groups[task] = {
"type": "task", "type": "task",
"yaml_path": yaml_path, "yaml_path": yaml_path,
} }
if "group" in config: if "group" in config:
groups = config["group"] groups = config["group"]
...@@ -343,6 +352,7 @@ class TaskManager: ...@@ -343,6 +352,7 @@ class TaskManager:
return tasks_and_groups return tasks_and_groups
def include_path(task_dir): def include_path(task_dir):
logger = utils.eval_logger logger = utils.eval_logger
logger.setLevel(getattr(logging, "INFO")) logger.setLevel(getattr(logging, "INFO"))
...@@ -352,6 +362,7 @@ def include_path(task_dir): ...@@ -352,6 +362,7 @@ def include_path(task_dir):
) )
return 0 return 0
def initialize_tasks(verbosity="INFO"): def initialize_tasks(verbosity="INFO"):
logger = utils.eval_logger logger = utils.eval_logger
logger.setLevel(getattr(logging, f"{verbosity}")) logger.setLevel(getattr(logging, f"{verbosity}"))
...@@ -362,6 +373,7 @@ def initialize_tasks(verbosity="INFO"): ...@@ -362,6 +373,7 @@ def initialize_tasks(verbosity="INFO"):
) )
return 0 return 0
def get_task_name_from_config(task_config: Dict[str, str]) -> str: def get_task_name_from_config(task_config: Dict[str, str]) -> str:
if "task" in task_config: if "task" in task_config:
return task_config["task"] return task_config["task"]
...@@ -370,6 +382,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str: ...@@ -370,6 +382,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
else: else:
return "{dataset_path}".format(**task_config) return "{dataset_path}".format(**task_config)
def get_task_name_from_object(task_object): def get_task_name_from_object(task_object):
if hasattr(task_object, "config"): if hasattr(task_object, "config"):
return task_object._config["task"] return task_object._config["task"]
...@@ -382,7 +395,10 @@ def get_task_name_from_object(task_object): ...@@ -382,7 +395,10 @@ def get_task_name_from_object(task_object):
else type(task_object).__name__ else type(task_object).__name__
) )
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None):
def get_task_dict(
task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None
):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object. """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]] :param task_name_list: List[Union[str, Dict, Task]]
...@@ -409,7 +425,9 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: Ta ...@@ -409,7 +425,9 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: Ta
if task_manager is None: if task_manager is None:
task_manager = TaskManager() task_manager = TaskManager()
task_name_from_string_dict = task_manager.load_task_or_group(string_task_name_list) task_name_from_string_dict = task_manager.load_task_or_group(
string_task_name_list
)
for task_element in others_task_name_list: for task_element in others_task_name_list:
if isinstance(task_element, dict): if isinstance(task_element, dict):
...@@ -427,6 +445,7 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: Ta ...@@ -427,6 +445,7 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: Ta
assert set(task_name_from_string_dict.keys()).isdisjoint( assert set(task_name_from_string_dict.keys()).isdisjoint(
set(task_name_from_object_dict.keys()) set(task_name_from_object_dict.keys())
) )
return { return {
**task_name_from_string_dict, **task_name_from_string_dict,
**task_name_from_config_dict, **task_name_from_config_dict,
......
# ArabicMMLU
### Paper
ArabicMMLU: Measuring massive multitask language understanding in Arabic
This dataset has been translated from the original MMLU with the help of GPT-4.
The original dataset is [MMLU](https://arxiv.org/pdf/2009.03300v3.pdf).
The translation was produced in collaboration with the AceGPT researchers ([AceGPT](https://arxiv.org/abs/2309.12053)).
ArabicMMLU is a comprehensive evaluation benchmark specifically designed to evaluate the knowledge and reasoning abilities of LLMs within the context of Arabic language and culture.
ArabicMMLU covers a wide range of subjects, comprising 57 topics that span from elementary to advanced professional levels.
Homepage: [AceGPT Homepage](https://github.com/FreedomIntelligence/AceGPT/tree/main/eval/benchmark_eval/benchmarks/MMLUArabic)
### Citation
### Groups and Tasks
#### Groups
- `ammlu`: All 57 subjects of the ArabicMMLU dataset, evaluated following the methodology in MMLU's original implementation.
#### Tasks
The following tasks evaluate individual subjects of the ArabicMMLU dataset using loglikelihood-based multiple-choice scoring, where `{subject_english}` is the English name of the subject:
- `ammlu_{subject_english}`
### Checklist
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation?
* [x] Yes, original implementation contributed by author of the benchmark
If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
# Group-level task config for ammlu: MMLU translated to Arabic with GPT-4.
group: ammlu
dataset_path: Hennara/ammlu  # Hugging Face Hub dataset id
test_split: test
fewshot_split: dev  # few-shot examples are drawn from the dev split
fewshot_config:
  sampler: first_n  # deterministic: always take the first N dev examples
output_type: multiple_choice
# Prompt: the question followed by the four options; the trailing
# "الجواب:" is Arabic for "Answer:".
doc_to_text: "{{Question.strip()}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nالجواب:"
doc_to_choice: ["A", "B", "C", "D"]
# Target is the index of the gold letter stored in the Answer column.
doc_to_target: "{{['A', 'B', 'C', 'D'].index(Answer)}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
  - metric: acc_norm  # accuracy under length-normalized loglikelihoods
    aggregation: mean
    higher_is_better: true
metadata:
  version: 0.0
"""
Take in a YAML, and output all other splits with this YAML
"""
import argparse
import os
import yaml
from tqdm import tqdm
# Mapping from each MMLU subject name (English, used as the HF dataset config
# name) to the Arabic label of its broad category. The category string is
# interpolated into the per-subject task description when no CoT prompt file
# is supplied. The labels are (in order of first appearance): STEM /
# computer science & mathematics, other sciences, social sciences, humanities.
SUBJECTS = {
    "abstract_algebra": "ألعلوم وتقنية المعلومات و الرياضيات",
    "anatomy": "ألعلوم وتقنية المعلومات و الرياضيات",
    "astronomy": "ألعلوم وتقنية المعلومات و الرياضيات",
    "business_ethics": "علوم أخرى",
    "clinical_knowledge": "علوم أخرى",
    "college_biology": "ألعلوم وتقنية المعلومات و الرياضيات",
    "college_chemistry": "ألعلوم وتقنية المعلومات و الرياضيات",
    "college_computer_science": "ألعلوم وتقنية المعلومات و الرياضيات",
    "college_mathematics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "college_medicine": "علوم أخرى",
    "college_physics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "computer_security": "ألعلوم وتقنية المعلومات و الرياضيات",
    "conceptual_physics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "econometrics": "العلوم الإجتماعية",
    "electrical_engineering": "ألعلوم وتقنية المعلومات و الرياضيات",
    "elementary_mathematics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "formal_logic": "العلوم الانسانية",
    "global_facts": "علوم أخرى",
    "high_school_biology": "ألعلوم وتقنية المعلومات و الرياضيات",
    "high_school_chemistry": "ألعلوم وتقنية المعلومات و الرياضيات",
    "high_school_computer_science": "ألعلوم وتقنية المعلومات و الرياضيات",
    "high_school_european_history": "العلوم الانسانية",
    "high_school_geography": "العلوم الإجتماعية",
    "high_school_government_and_politics": "العلوم الإجتماعية",
    "high_school_macroeconomics": "العلوم الإجتماعية",
    "high_school_mathematics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "high_school_microeconomics": "العلوم الإجتماعية",
    "high_school_physics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "high_school_psychology": "العلوم الإجتماعية",
    "high_school_statistics": "ألعلوم وتقنية المعلومات و الرياضيات",
    "high_school_us_history": "العلوم الانسانية",
    "high_school_world_history": "العلوم الانسانية",
    "human_aging": "علوم أخرى",
    "human_sexuality": "العلوم الإجتماعية",
    "international_law": "العلوم الانسانية",
    "jurisprudence": "العلوم الانسانية",
    "logical_fallacies": "العلوم الانسانية",
    "machine_learning": "ألعلوم وتقنية المعلومات و الرياضيات",
    "management": "علوم أخرى",
    "marketing": "علوم أخرى",
    "medical_genetics": "علوم أخرى",
    "miscellaneous": "علوم أخرى",
    "moral_disputes": "العلوم الانسانية",
    "moral_scenarios": "العلوم الانسانية",
    "nutrition": "علوم أخرى",
    "philosophy": "العلوم الانسانية",
    "prehistory": "العلوم الانسانية",
    "professional_accounting": "علوم أخرى",
    "professional_law": "العلوم الانسانية",
    "professional_medicine": "علوم أخرى",
    "professional_psychology": "العلوم الإجتماعية",
    "public_relations": "العلوم الإجتماعية",
    "security_studies": "العلوم الإجتماعية",
    "sociology": "العلوم الإجتماعية",
    "us_foreign_policy": "العلوم الإجتماعية",
    "virology": "علوم أخرى",
    "world_religions": "العلوم الانسانية",
}
def parse_args(args=None):
    """Parse command-line options for the per-subject YAML generator.

    Args:
        args: Optional list of argument strings. Defaults to None, which
            makes argparse fall back to ``sys.argv[1:]`` — identical to the
            previous behavior — while allowing programmatic/test invocation.

    Returns:
        argparse.Namespace with ``base_yaml_path`` (required),
        ``save_prefix_path`` (default "ammlu"), ``cot_prompt_path``
        (default None) and ``task_prefix`` (default "") attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_yaml_path", required=True)
    parser.add_argument("--save_prefix_path", default="ammlu")
    parser.add_argument("--cot_prompt_path", default=None)
    parser.add_argument("--task_prefix", default="")
    return parser.parse_args(args)
if __name__ == "__main__":
    args = parse_args()

    # Only the filename (not the directory) goes into each generated YAML's
    # `include:` field, since the generated files sit next to the base YAML.
    base_yaml_name = os.path.basename(args.base_yaml_path)
    with open(args.base_yaml_path, encoding="utf-8") as f:
        base_yaml = yaml.full_load(f)

    use_cot = args.cot_prompt_path is not None
    if use_cot:
        import json

        # Chain-of-thought prompts: a JSON object keyed by subject name.
        with open(args.cot_prompt_path, encoding="utf-8") as f:
            cot_file = json.load(f)

    for subject, category in tqdm(SUBJECTS.items()):
        if use_cot:
            description = cot_file[subject]
        else:
            # NOTE(review): "فم" here looks like a typo for "قم" ("perform
            # the evaluation in the field of ..."). The checked-in generated
            # YAMLs carry the same spelling, so fixing it means regenerating
            # them all — confirm with the task authors before changing.
            description = f"فم بعملية التقييم في مجال {category} \n\n"

        # Optional task-name prefix, e.g. "ammlu_<prefix>_<subject>".
        if args.task_prefix != "":
            task_name = f"ammlu_{args.task_prefix}_{subject}"
        else:
            task_name = f"ammlu_{subject}"

        subject_cfg = {
            "include": base_yaml_name,
            "task": task_name,
            "dataset_name": subject,
            "description": description,
        }

        file_save_path = args.save_prefix_path + f"_{subject}.yaml"
        print(f"Saving yaml for subset {subject} to {file_save_path}")
        with open(file_save_path, "w", encoding="utf-8") as yaml_file:
            yaml.dump(
                subject_cfg,
                yaml_file,
                width=float("inf"),
                allow_unicode=True,
                default_style='"',
            )
"dataset_name": "abstract_algebra"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_abstract_algebra"
"dataset_name": "anatomy"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_anatomy"
"dataset_name": "astronomy"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_astronomy"
"dataset_name": "business_ethics"
"description": "فم بعملية التقييم في مجال علوم أخرى \n\n"
"include": "_default_template_yaml"
"task": "ammlu_business_ethics"
"dataset_name": "clinical_knowledge"
"description": "فم بعملية التقييم في مجال علوم أخرى \n\n"
"include": "_default_template_yaml"
"task": "ammlu_clinical_knowledge"
"dataset_name": "college_biology"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_biology"
"dataset_name": "college_chemistry"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_chemistry"
"dataset_name": "college_computer_science"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_computer_science"
"dataset_name": "college_mathematics"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_mathematics"
"dataset_name": "college_medicine"
"description": "فم بعملية التقييم في مجال علوم أخرى \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_medicine"
"dataset_name": "college_physics"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_physics"
"dataset_name": "computer_security"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_computer_security"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment