Commit abd17276 authored by Baber

Merge branch 'smolrefact' into tasklist

# Conflicts:
#	lm_eval/__main__.py
#	lm_eval/api/group.py
#	lm_eval/api/task.py
#	lm_eval/evaluator_utils.py
#	lm_eval/tasks/__init__.py
#	lm_eval/utils.py
#	pyproject.toml
parents 00afd536 70314843
from __future__ import annotations
import logging
import os
from functools import cached_property
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any
from lm_eval.api.registry import register_model
from lm_eval.models.api_models import TemplateAPI
......@@ -26,9 +28,9 @@ class LocalCompletionsAPI(TemplateAPI):
def _create_payload(
self,
messages: Union[List[List[int]], List[dict], List[str], str],
messages: list[list[int]] | list[dict] | list[str] | str,
generate=False,
gen_kwargs: Optional[dict] = None,
gen_kwargs: dict | None = None,
seed: int = 1234,
eos=None,
**kwargs,
......@@ -50,24 +52,23 @@ class LocalCompletionsAPI(TemplateAPI):
"seed": seed,
**gen_kwargs,
}
else:
return {
"model": self.model,
"prompt": messages,
"temperature": 0,
"max_tokens": 1,
"logprobs": 1,
"seed": seed,
"echo": True,
}
return {
"model": self.model,
"prompt": messages,
"temperature": 0,
"max_tokens": 1,
"logprobs": 1,
"seed": seed,
"echo": True,
}
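# Note: with echo=True, max_tokens=1 and logprobs=1, an OpenAI-compatible
# /v1/completions endpoint returns logprobs for the prompt tokens themselves;
# parse_logprobs() below slices them with ctxlens to score each continuation.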
@staticmethod
def parse_logprobs(
outputs: Union[Dict, List[Dict]],
tokens: List[List[int]] = None,
ctxlens: List[int] = None,
outputs: dict | list[dict],
tokens: list[list[int]] | None = None,
ctxlens: list[int] | None = None,
**kwargs,
) -> List[Tuple[float, bool]]:
) -> list[tuple[float, bool]]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
......@@ -88,7 +89,7 @@ class LocalCompletionsAPI(TemplateAPI):
return res
@staticmethod
def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
......@@ -130,9 +131,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
def _create_payload(
self,
messages: List[Dict],
messages: list[dict],
generate=False,
gen_kwargs: dict = None,
gen_kwargs: dict | None = None,
seed=1234,
eos=None,
**kwargs,
......@@ -160,7 +161,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
}
@staticmethod
def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
......@@ -173,11 +174,11 @@ class LocalChatCompletion(LocalCompletionsAPI):
def tok_encode(
self,
string: Union[str, Any],
string: str | Any,
left_truncate_len=None,
add_special_tokens=None,
**kwargs,
) -> Union[List[str], List[int], Any]:
) -> list[str] | list[int] | Any:
return string
def loglikelihood(self, requests, **kwargs):
......@@ -219,7 +220,7 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
)
return super().loglikelihood(requests, **kwargs)
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
def chat_template(self, chat_template: bool | str = False) -> str | None:
return ""
......@@ -261,7 +262,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
def _create_payload(
self,
messages: List[Dict],
messages: list[dict],
generate=False,
gen_kwargs: dict = None,
seed=1234,
......@@ -289,7 +290,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
"seed": seed,
**gen_kwargs,
}
if "o1" in self.model:
if "o1" in self.model or "5" in self.model:
output.pop("stop")
output["temperature"] = 1
elif "o3" in self.model:
......
......@@ -28,9 +28,8 @@ class OptimumLM(HFLM):
**kwargs,
) -> None:
if "backend" in kwargs:
# optimum currently only supports causal models
assert kwargs["backend"] == "causal", (
"Currently, only OVModelForCausalLM is supported."
assert kwargs["backend"] in ["causal", "seq2seq"], (
"Currently, only OVModelForCausalLM or OVModelForSeq2SeqLM are supported."
)
self.openvino_device = device
......@@ -54,7 +53,7 @@ class OptimumLM(HFLM):
"package `optimum` is not installed. Please install it via `pip install optimum[openvino]`"
)
else:
from optimum.intel.openvino import OVModelForCausalLM
from optimum.intel.openvino import OVModelForCausalLM, OVModelForSeq2SeqLM
model_kwargs = kwargs if kwargs else {}
if "ov_config" in model_kwargs:
......@@ -76,17 +75,14 @@ class OptimumLM(HFLM):
model_kwargs["ov_config"]["MODEL_DISTRIBUTION_POLICY"] = (
"PIPELINE_PARALLEL"
)
model_file = Path(pretrained) / "openvino_model.xml"
if model_file.exists():
export = False
else:
export = True
self._model = OVModelForCausalLM.from_pretrained(
model_cls = (
OVModelForCausalLM if self.backend == "causal" else OVModelForSeq2SeqLM
)
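# Backend selection is expected to arrive via model args, e.g.
# `--model_args pretrained=<path>,backend=seq2seq` (flag spelling assumed); "causal" remains the default.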
self._model = model_cls.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
export=export,
device=self.openvino_device.upper(),
**model_kwargs,
)
......@@ -216,7 +216,7 @@ class SGLangLM(TemplateLM):
# requests are no longer grouped by their generation_kwargs; per-request sampling
# params are built below, so mixed generation settings can share a batch.
re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
re_ords = Collator(requests, _collate_gen, group_by=None)
chunks = re_ords.get_batched(
n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
)
......@@ -232,36 +232,41 @@ class SGLangLM(TemplateLM):
context_and_encoding, all_gen_kwargs = zip(*chunk)
context, context_encoding = zip(*context_and_encoding)
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
# add EOS token to stop sequences
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
context_encoding_truncated = []
sampling_params = []
for x, gen_kwargs in zip(context_encoding, all_gen_kwargs):
# unpack our keyword arguments.
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
# add EOS token to stop sequences
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# set the max length in tokens of inputs ("context_enc")
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
if len(x) > max_ctx_len:
context_encoding_truncated.append(x[-max_ctx_len:])
else:
context_encoding_truncated.append(x)
# create sampling params
kwargs = self.modify_gen_kwargs(kwargs)
sampling_params.append(
kwargs | {"max_tokens": max_gen_toks, "stop": until}
)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# set the max length in tokens of inputs ("context_enc")
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
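# e.g. with max_length=4096 and max_gen_toks=256 (illustrative numbers), only the
# last 3840 context tokens are kept by the truncation below.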
context_encoding = [x[-max_ctx_len:] for x in context_encoding]
# perform batched generation
# cont is a list of dic. See here https://github.com/sgl-project/sglang/blob/0a6f18f068e4095fc228e798454e8496c9749214/python/sglang/srt/entrypoints/engine.py#L111 .
cont = self._model_generate(
requests=context_encoding,
requests=context_encoding_truncated,
generate=True,
max_tokens=max_gen_toks,
stop=until,
**kwargs,
sampling_params=sampling_params,
)
# cache generations
......@@ -284,28 +289,22 @@ class SGLangLM(TemplateLM):
self,
requests: List[List[int]] = None,
generate: bool = False,
max_tokens: int = None,
stop: Optional[List[str]] = None,
sampling_params: Union[List[Dict], Dict, None] = None,
return_logprob: bool = False,
top_logprobs_num: int = 1,
logprob_start_len: int = -1,
**kwargs,
):
# check sglang sampling parameters: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/sampling/sampling_params.py#L21 and https://docs.sglang.ai/references/sampling_params.html.
if generate:
kwargs = self.modify_gen_kwargs(kwargs)
sampling_params = {
"max_new_tokens": max_tokens,
"stop": stop,
}
sampling_params.update(kwargs)
else:
sampling_params = {
"temperature": 0,
"max_new_tokens": 1,
}
sampling_params.update(kwargs)
if not generate:
sampling_params = sampling_params if sampling_params else {}
sampling_params.update(
{
"temperature": 0,
"max_new_tokens": 1,
}
)
if not isinstance(sampling_params, list):
sampling_params = [sampling_params] * len(requests)
# Refer to: https://docs.sglang.ai/backend/offline_engine_api.html
outputs = self.model.generate(
input_ids=requests,
......
from __future__ import annotations
import copy
import gc
import inspect
import logging
import os
from importlib.metadata import version
......@@ -8,7 +9,7 @@ from importlib.util import find_spec
from multiprocessing import Process, Queue
from queue import Empty
from time import sleep
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Literal
import jinja2
from more_itertools import distribute
......@@ -33,7 +34,7 @@ from lm_eval.utils import (
try:
import ray
from vllm import LLM, SamplingParams
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import get_open_port
......@@ -41,7 +42,7 @@ try:
if parse_version(version("vllm")) >= parse_version("0.8.3"):
from vllm.entrypoints.chat_utils import resolve_hf_chat_template
except ModuleNotFoundError:
pass
print("njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd")
if TYPE_CHECKING:
pass
......@@ -51,7 +52,7 @@ eval_logger = logging.getLogger(__name__)
def _vllm_mp_worker(
model_args: dict,
sampling_params: "SamplingParams",
sampling_params: list["SamplingParams"],
requests: list[list[int]],
lora_request: "LoRARequest",
result_queue: "Queue",
......@@ -79,7 +80,7 @@ def _vllm_mp_worker(
try:
llm = LLM(**model_args)
res = llm.generate(
prompt_token_ids=requests,
[TokensPrompt(prompt_token_ids=request) for request in requests],
sampling_params=sampling_params,
lora_request=lora_request,
)
......@@ -114,30 +115,30 @@ class VLLM(TemplateLM):
self,
pretrained: str,
dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
revision: Optional[str] = None,
trust_remote_code: Optional[bool] = False,
tokenizer: Optional[str] = None,
revision: str | None = None,
trust_remote_code: bool | None = False,
tokenizer: str | None = None,
tokenizer_mode: Literal["auto", "slow"] = "auto",
tokenizer_revision: Optional[str] = None,
add_bos_token: Optional[bool] = False,
prefix_token_id: Optional[int] = None,
tokenizer_revision: str | None = None,
add_bos_token: bool | None = False,
prefix_token_id: int | None = None,
tensor_parallel_size: int = 1,
quantization: Optional[str] = None,
quantization: str | None = None,
max_gen_toks: int = 256,
swap_space: int = 4,
batch_size: Union[str, int] = 1,
max_batch_size=None,
max_length: int = None,
max_model_len: int = None,
batch_size: str | int = 1,
max_batch_size: int | None = None,
max_length: int | None = None,
max_model_len: int | None = None,
seed: int = 1234,
gpu_memory_utilization: float = 0.9,
data_parallel_size: int = 1,
lora_local_path: str = None,
lora_local_path: str | None = None,
# VLLM: enable thinking tags in the prompt.
enable_thinking: bool = True,
chat_template_args: Optional[dict] = None,
# End marker for thinking tags - splits to get response after this token (if provided).
think_end_token: Optional[str] = None,
think_end_token: str | None = None,
max_lora_rank: int = 16,
**kwargs,
):
......@@ -173,7 +174,7 @@ class VLLM(TemplateLM):
"swap_space": int(swap_space),
"quantization": quantization,
"seed": int(seed),
"enable_lora": True if lora_local_path else False,
"enable_lora": bool(lora_local_path),
"max_lora_rank": int(max_lora_rank),
}
self.model_args.update(kwargs)
......@@ -196,6 +197,12 @@ class VLLM(TemplateLM):
self.batch_size = "auto"
eval_logger.info("Manual batching is not compatible with data parallelism.")
if "gemma" in pretrained.lower():
add_bos_token = True
eval_logger.info(
"Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
)
from transformers import AutoConfig
self._config = AutoConfig.from_pretrained(
......@@ -214,11 +221,6 @@ class VLLM(TemplateLM):
"enable_thinking", enable_thinking
)
self.add_bos_token = add_bos_token
if "gemma" in pretrained.lower():
self.add_bos_token = True
eval_logger.info(
"Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
)
if parse_version(version("vllm")) >= parse_version("0.8.3"):
kwargs_resolve_hf_chat_template = {
......@@ -239,13 +241,6 @@ class VLLM(TemplateLM):
model_config = engine_args.create_model_config()
kwargs_resolve_hf_chat_template["model_config"] = model_config
# https://github.com/vllm-project/vllm/pull/18259
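# (some vLLM releases spell the parameter "trsut_remote_code"; the check below supports both spellings)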
if (
"trsut_remote_code"
in inspect.signature(resolve_hf_chat_template).parameters
):
kwargs_resolve_hf_chat_template["trsut_remote_code"] = trust_remote_code
else:
kwargs_resolve_hf_chat_template["trust_remote_code"] = trust_remote_code
......@@ -307,7 +302,7 @@ class VLLM(TemplateLM):
return self._max_gen_toks
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Method to apply a chat template to a list of chat history between user and model.
......@@ -344,14 +339,14 @@ class VLLM(TemplateLM):
def tok_encode(
self,
string: Union[str, List[str]],
string: str | list[str],
left_truncate_len: int = None,
add_special_tokens: bool = False,
truncation: bool = False,
) -> Union[List[int], List[List[int]]]:
) -> list[int] | list[list[int]]:
if not add_special_tokens:
add_special_tokens = False or self.add_bos_token
encoding: Union[List[List[int]], List[int]] = self.tokenizer(
encoding: list[list[int]] | list[int] = self.tokenizer(
string,
add_special_tokens=add_special_tokens,
truncation=truncation,
......@@ -369,19 +364,16 @@ class VLLM(TemplateLM):
def _model_generate(
self,
requests: List[List[int]] = None,
requests: list[list[int]] | None = None,
generate: bool = False,
max_tokens: int = None,
stop: Optional[List[str]] = None,
**kwargs,
sampling_params: Union[List["SamplingParams"], "SamplingParams", None] = None,
):
if generate:
kwargs = self.modify_gen_kwargs(kwargs)
sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
else:
if not generate or sampling_params is None:
sampling_params = SamplingParams(
temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
)
if not isinstance(sampling_params, list):
sampling_params = [sampling_params] * len(requests)
if self.data_parallel_size > 1 and not self.V1:
# vLLM hangs if resources are set in ray.remote
# also seems to only work with decorator and not with ray.remote() fn
......@@ -389,13 +381,13 @@ class VLLM(TemplateLM):
@ray.remote
def run_inference_one_model(
model_args: dict,
sampling_params: SamplingParams,
requests: List[List[int]],
lora_request: LoRARequest,
sampling_params: list["SamplingParams"],
requests: list[list[int]],
lora_request: "LoRARequest",
):
llm = LLM(**model_args)
return llm.generate(
prompt_token_ids=requests,
[TokensPrompt(prompt_token_ids=request) for request in requests],
sampling_params=sampling_params,
lora_request=lora_request,
)
......@@ -403,9 +395,12 @@ class VLLM(TemplateLM):
# dispatch requests to all self.data_parallel_size workers, in interleaved fashion
# interleaved important to balance context lengths across workers
requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
sampling_params = [
list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
]
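# e.g. distribute(2, [r0, r1, r2, r3]) yields [r0, r2] and [r1, r3]; splitting requests
# and sampling_params the same way keeps each request paired with its own params per worker.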
inputs = (
(self.model_args, sampling_params, req, self.lora_request)
for req in requests
(self.model_args, sp, req, self.lora_request)
for req, sp in zip(requests, sampling_params)
)
object_refs = [run_inference_one_model.remote(*x) for x in inputs]
results = ray.get(object_refs)
......@@ -420,16 +415,18 @@ class VLLM(TemplateLM):
dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()
requests = (list(x) for x in distribute(self.data_parallel_size, requests))
sampling_params = (
list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
)
procs, resq = [], Queue()
# We use Process as it is non-daemonic
try:
for rank, req in enumerate(requests):
for rank, (sp, req) in enumerate(zip(requests, sampling_params)):
proc = Process(
target=_vllm_mp_worker,
args=(
self.model_args.copy(),
sampling_params,
sp,
req,
self.lora_request,
resq,
......@@ -459,7 +456,7 @@ class VLLM(TemplateLM):
if dead_procs:
raise RuntimeError(
f"Worker processes {dead_procs} died unexpectedly"
)
) from None
continue
results = [rank_res[i] for i in range(len(procs))]
......@@ -484,16 +481,16 @@ class VLLM(TemplateLM):
else:
outputs = self.model.generate(
prompt_token_ids=requests,
[TokensPrompt(prompt_token_ids=request) for request in requests],
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
use_tqdm=self.batch_size == "auto",
lora_request=self.lora_request,
)
return outputs
def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[float]:
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[float]:
adaptive_batch_size = None
if self.batch_size == "auto":
adaptive_batch_size = len(requests)
......@@ -508,7 +505,7 @@ class VLLM(TemplateLM):
disable=(disable_tqdm or (self.rank != 0)),
)
):
rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
rolling_token_windows: list[tuple[list[int], list[int]]] = list(
map(
make_disjoint_window,
get_rolling_token_windows(
......@@ -561,13 +558,13 @@ class VLLM(TemplateLM):
return loglikelihoods
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[str]:
res = []
# batch tokenize contexts
context, all_gen_kwargs = zip(*(req.args for req in requests))
context_encoding: List[List[int]] = self.tok_encode(
context_encoding: list[list[int]] = self.tok_encode(
context, add_special_tokens=self.add_bos_token
)
requests = [
......@@ -583,10 +580,11 @@ class VLLM(TemplateLM):
# - any OOMs will happen right away rather than near the end
return -len(_requests[0][1]), _requests[0][0]
# requests are no longer grouped by their generation_kwargs; each request gets its
# own SamplingParams below, so mixed generation settings can share a batch.
re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
re_ords = Collator(
requests,
_collate_gen,
group_by=None,
)
chunks = re_ords.get_batched(
n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
)
......@@ -601,45 +599,48 @@ class VLLM(TemplateLM):
for chunk in chunks:
context_and_encoding, all_gen_kwargs = zip(*chunk)
context, context_encoding = zip(*context_and_encoding)
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
# add EOS token to stop sequences
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# set the max length in tokens of inputs ("context_enc")
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
all_lengths = [len(x) for x in context_encoding]
for length in all_lengths:
if length > max_ctx_len:
context_encoding_truncated = []
sampling_params = []
for x, gen_kwargs in zip(context_encoding, all_gen_kwargs):
# unpack our keyword arguments.
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
# add EOS token to stop sequences
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# set the max length in tokens of inputs ("context_enc")
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
if len(x) > max_ctx_len:
eval_logger.warning(
f"Context length {length} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
f"Context length {len(x)} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
)
context_encoding = [x[-max_ctx_len:] for x in context_encoding]
context_encoding_truncated.append(x[-max_ctx_len:])
else:
context_encoding_truncated.append(x)
# create sampling params
kwargs = self.modify_gen_kwargs(kwargs)
sampling_params.append(
SamplingParams(max_tokens=max_gen_toks, stop=until, **kwargs)
)
# perform batched generation
cont = self._model_generate(
requests=context_encoding,
requests=context_encoding_truncated,
generate=True,
max_tokens=max_gen_toks,
stop=until,
**kwargs,
sampling_params=sampling_params,
)
# cache generations
for output, context in zip(cont, context):
for output, context_ in zip(cont, context):
generated_text: str = output.outputs[0].text
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
generated_text = postprocess_generated_text(
......@@ -647,7 +648,7 @@ class VLLM(TemplateLM):
)
res.append(generated_text)
self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), generated_text
"generate_until", (context_, gen_kwargs), generated_text
)
pbar.update(1)
......@@ -657,9 +658,9 @@ class VLLM(TemplateLM):
def _loglikelihood_tokens(
self,
requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
requests: list[tuple[tuple[str, str], list[int], list[int]]],
disable_tqdm: bool = False,
) -> List[Tuple[float, bool]]:
) -> list[tuple[float, bool]]:
res = []
def _collate(x):
......@@ -680,7 +681,7 @@ class VLLM(TemplateLM):
for chunk in chunks:
inputs = []
ctxlens = []
for cache_key, context_enc, continuation_enc in chunk:
for _cache_key, context_enc, continuation_enc in chunk:
if (
full_length := len(context_enc + continuation_enc)
) > self.max_length:
......@@ -718,7 +719,7 @@ class VLLM(TemplateLM):
return re_ord.get_original(res)
@staticmethod
def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
"""Process logprobs and tokens.
:param tokens: list
......
# Tasks
A list of supported tasks and task groupings can be viewed with `lm-eval --tasks list`.

For more information, including a full list of task names and their precise meanings or sources, follow the links provided to the individual README.md files for each subfolder.
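The same list can also be pulled programmatically. A minimal sketch, assuming a recent harness version where `TaskManager` is the task registry:

```python
from lm_eval.tasks import TaskManager

# Build the task registry and print a few registered task names.
task_manager = TaskManager()
print(sorted(task_manager.all_tasks)[:10])
```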
| Task Family | Description | Language(s) |
|--------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------|
| [eq-bench_es](eq_bench/README.md) | Spanish version of EQ-Bench (EN). Task for evaluating emotional reasoning through dialogue-based prompts. [Hugging Face](https://huggingface.co/datasets/BSC-LT/EQ-bench_es) |Spanish **Human Translated** |
| [eq-bench_ca](eq_bench/README.md) | Catalan version of EQ-Bench (EN). Task for evaluating emotional reasoning through dialogue-based prompts. [Hugging Face](https://huggingface.co/datasets/BSC-LT/EQ-bench_ca)| Catalan **Human Translated** |
| [aclue](aclue/README.md) | Tasks focusing on ancient Chinese language understanding and cultural aspects. | Ancient Chinese |
| [acp_bench](acpbench/README.md) | Tasks evaluating the reasoning ability about Action, Change, and Planning | English |
| [acp_bench_hard](acpbench/README.md) | Tasks evaluating the reasoning ability about Action, Change, and Planning | English |
| [aexams](aexams/README.md) | Tasks in Arabic related to various academic exams covering a range of subjects. | Arabic |
| [agieval](agieval/README.md) | Tasks involving historical data or questions related to history and historical texts. | English, Chinese |
| [aime](aime/README.md) | High school math competition questions | English |
| [anli](anli/README.md) | Adversarial natural language inference tasks designed to test model robustness. | English |
| [arabic_leaderboard_complete](arabic_leaderboard_complete/README.md) | A full version of the tasks in the Open Arabic LLM Leaderboard, focusing on the evaluation of models that reflect the characteristics of Arabic language understanding and comprehension, culture, and heritage. Note that some of these tasks are machine-translated. | Arabic (Some MT) |
| [arabic_leaderboard_light](arabic_leaderboard_light/README.md) | A light version of the tasks in the Open Arabic LLM Leaderboard (i.e., 10% samples of the test set in the original benchmarks), focusing on the evaluation of models that reflect the characteristics of Arabic language understanding and comprehension, culture, and heritage. Note that some of these tasks are machine-translated. | Arabic (Some MT) |
| [arabicmmlu](arabicmmlu/README.md) | Localized Arabic version of MMLU with multiple-choice questions from 40 subjects. | Arabic |
| [ArabCulture](arab_culture/README.md) | Benchmark for evaluating modeles' commonsense cultural knowledge across different 13 different Arab Countries. | Arabic |
| [ArabCulture](arab_culture/README.md)                                     | Benchmark for evaluating models' commonsense cultural knowledge across 13 different Arab countries. | Arabic |
| [AraDICE](aradice/README.md) | A collection of multiple tasks carefully designed to evaluate dialectal and cultural capabilities in large language models (LLMs). | Arabic |
| [arc](arc/README.md) | Tasks involving complex reasoning over a diverse set of questions. | English |
| [arithmetic](arithmetic/README.md) | Tasks involving numerical computations and arithmetic reasoning. | English |
| [asdiv](asdiv/README.md) | Tasks involving arithmetic and mathematical reasoning challenges. | English |
| [babi](babi/README.md) | Tasks designed as question and answering challenges based on simulated stories. | English |
| [babilong](babilong/README.md) | Tasks designed to test whether models can find and reason over facts in long contexts. | English |
| [basque_bench](basque_bench/README.md) | Collection of tasks in Basque encompassing various evaluation areas. | Basque |
| [basqueglue](basqueglue/README.md) | Tasks designed to evaluate language understanding in Basque language. | Basque |
| [bbh](bbh/README.md) | Tasks focused on deep semantic understanding through hypothesization and reasoning. | English, German |
......@@ -29,30 +33,36 @@
| [belebele](belebele/README.md) | Language understanding tasks in a variety of languages and scripts. | Multiple (122 languages) |
| benchmarks | General benchmarking tasks that test a wide range of language understanding capabilities. | |
| [bertaqa](bertaqa/README.md) | Local Basque cultural trivia QA tests in English and Basque languages. | English, Basque, Basque (MT) |
| [bhs](bhs/README.md) | Grammatical knowledge evaluation for low-resource languages. | Basque, Hindi, Swahili |
| [bigbench](bigbench/README.md) | Broad tasks from the BIG-bench benchmark designed to push the boundaries of large models. | Multiple |
| [blimp](blimp/README.md) | Tasks testing grammatical phenomena to evaluate language model's linguistic capabilities. | English |
| [blimp_nl](blimp_nl/README.md) | A benchmark evaluating language models' grammatical capabilities in Dutch based on comparing the probabilities of minimal pairs of grammatical and ungrammatical sentences. | Dutch |
| [c4](c4/README.md) | Tasks based on a colossal, cleaned version of Common Crawl's web crawl corpus to assess models' language modeling capabilities. | English |
| [cabbq](cabbq/README.md) | Adaptation of the [BBQ](bbq/README.md) benchmark to the Catalan language and stereotypes prevalent in Spain. | Catalan |
| [careqa](careqa/README.md) | Multiple choice and open-ended medical question answering based on the Spanish Specialised Healthcare Training (MIR) exams. | English, Spanish |
| [catalan_bench](catalan_bench/README.md) | Collection of tasks in Catalan encompassing various evaluation areas. | Catalan |
| [ceval](ceval/README.md) | Tasks that evaluate language understanding and reasoning in an educational context. | Chinese |
| [cmmlu](cmmlu/README.md) | Multi-subject multiple choice question tasks for comprehensive academic assessment. | Chinese |
| code_x_glue | Tasks that involve understanding and generating code across multiple programming languages. | Go, Java, JS, PHP, Python, Ruby |
| [commonsense_qa](commonsense_qa/README.md) | CommonsenseQA, a multiple-choice QA dataset for measuring commonsense knowledge. | English |
| [copal_id](copal_id/README.md)                                            | Indonesian causal commonsense reasoning dataset that captures local nuances. | Indonesian |
| [coqa](coqa/README.md) | Conversational question answering tasks to test dialog understanding. | English |
| [crows_pairs](crows_pairs/README.md) | Tasks designed to test model biases in various sociodemographic groups. | English, French |
| [click](click/README.md) | A benchmark dataset of Cultural and Linguistic Intelligence in Korean (CLIcK), comprising 1,995 QA pairs sourced from official Korean exams and textbooks to test Korean cultural and linguistic knowledge. | Korean |
| csatqa | Tasks related to SAT and other standardized testing questions for academic assessment. | Korean |
| [darija_bench](darija_bench/README.md) | Traditional NLP tasks (Translation, Summariation, etc..) for Moroccan Darija | Moroccan Darija (some MT) |
| [darija_bench](darija_bench/README.md) | Traditional NLP tasks (Translation, Summarization, etc.) for Moroccan Darija | Moroccan Darija (some MT) |
| [darijahellaswag](darijahellaswag/README.md) | Moroccan Darija version of HellaSwag. | Moroccan Darija (MT) |
| [darijammlu](darijammlu/README.md) | Multiple-choice QA in Moroccan Darija (an Arabic dialect). | Moroccan Darija (MT) |
| [discrim_eval](discrim_eval/README.md) | Prompts for binary decisions covering 70 scenarios to evaluate demographic bias. | English |
| [drop](drop/README.md) | Tasks requiring numerical reasoning, reading comprehension, and question answering. | English |
| [egyhellaswag](egyhellaswag/README.md) | Egyptian Arabic (Masri) version of HellaSwag. | Egyptian Arabic (MT) |
| [egymmlu](egymmlu/README.md) | Multiple-choice QA in Egyptian Arabic. | Egyptian Arabic (MT) |
| [eq_bench](eq_bench/README.md) | Tasks focused on equality and ethics in question answering and decision-making. | English |
| [esbbq](esbbq/README.md) | Adaptation of the [BBQ](bbq/README.md) benchmark to the Spanish language and stereotypes prevalent in Spain. | Spanish |
| [eus_exams](eus_exams/README.md) | Tasks based on various professional and academic exams in the Basque language. | Basque |
| [eus_proficiency](eus_proficiency/README.md) | Tasks designed to test proficiency in the Basque language across various topics. | Basque |
| [eus_reading](eus_reading/README.md) | Reading comprehension tasks specifically designed for the Basque language. | Basque |
| [eus_trivia](eus_trivia/README.md)                                        | Trivia and knowledge testing tasks in the Basque language. | Basque |
| [evalita_LLM](evalita_llm/README.md) | A native Italian benchmark with diverse tasks formats and multiple prompts. | Italian |
| [fda](fda/README.md) | Tasks for extracting key-value pairs from FDA documents to test information extraction. | English |
| [fld](fld/README.md) | Tasks involving free-form and directed dialogue understanding. | English |
......@@ -71,13 +81,15 @@
| [histoires_morales](histoires_morales/README.md) | A dataset of structured narratives that describe normative and norm-divergent actions taken by individuals to accomplish certain intentions in concrete situations. | French (Some MT) |
| [hrm8k](hrm8k/README.md) | A challenging bilingual math reasoning benchmark for Korean and English. | Korean (Some MT), English (Some MT) |
| [humaneval](humaneval/README.md) | Code generation task that measure functional correctness for synthesizing programs from docstrings. | Python |
| [humaneval_infilling](humaneval_infilling/README.md) | Code generation task that measure fill-in-the-middle capability for synthesizing programs from docstrings. | Python |
| [icelandic_winogrande](icelandic_winogrande/README.md) | Manually translated and localized version of the [WinoGrande](winogrande/README.md) commonsense reasoning benchmark for Icelandic. | Icelandic |
| [ifeval](ifeval/README.md) | Instruction-following evaluation tasks that test whether models satisfy verifiable natural-language instructions. | English |
| [inverse_scaling](inverse_scaling/README.md) | Multiple-choice tasks from the Inverse Scaling Prize, designed to find settings where larger language models perform worse. | English |
| [japanese_leaderboard](japanese_leaderboard/README.md) | Japanese language understanding tasks to benchmark model performance on various linguistic aspects. | Japanese |
| [jsonschema_bench](jsonschema_bench/README.md) | Evaluate the ability of LLMs to generate JSON objects that conform to a given JSON schema, including API, configuration files, and other structured data formats. | JSON |
| [kbl](kbl/README.md) | Korean Benchmark for Legal Language Understanding. | Korean |
| [kmmlu](kmmlu/README.md) | Knowledge-based multi-subject multiple choice questions for academic evaluation. | Korean |
| [kobest](kobest/README.md)                                                | A collection of tasks designed to evaluate understanding in Korean language. | Korean |
| [kormedmcqa](kormedmcqa/README.md) | Medical question answering tasks in Korean to test specialized domain knowledge. | Korean |
| [lambada](lambada/README.md) | Tasks designed to predict the endings of text passages, testing language prediction skills. | English |
| [lambada_cloze](lambada_cloze/README.md) | Cloze-style LAMBADA dataset. | English |
......@@ -85,9 +97,12 @@
| [lambada_multilingual_stablelm](lambada_multilingual_stablelm/README.md) | Multilingual LAMBADA dataset. Users should prefer evaluating on this version of the multilingual dataset instead of on `lambada_multilingual`. | German, English, Spanish, French, Italian, Dutch, Portuguese |
| [leaderboard](leaderboard/README.md) | Task group used by Hugging Face's [Open LLM Leaderboard v2](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard). Those tasks are static and will not change through time | English |
| [lingoly](lingoly/README.md) | Challenging logical reasoning benchmark in low-resource languages with controls for memorization | English, Multilingual |
| [libra](libra/README.md) | Evaluates long-context understanding in Russian across four complexity levels | Russian (MT) |
| [llama3](llama3/README.md) | Evals reproducing those provided by the LLAMA team in the Hugging Face repo (instruct) | English, Multilingual |
| [lm_syneval](lm_syneval/README.md) | Evaluates the syntactic capabilities of language models. | English |
| [logiqa](logiqa/README.md) | Logical reasoning tasks requiring advanced inference and deduction. | English, Chinese |
| [logiqa2](logiqa2/README.md) | Large-scale logical reasoning dataset adapted from the Chinese Civil Service Examination. | English, Chinese |
| [longbench](longbench/README.md) | LongBench evaluates language models' ability to understand lengthy texts across multiple tasks and languages. | English, Chinese |
| [mastermind](mastermind/README.md) | Reasoning benchmark based on the board game of Mastermind. | English |
| [mathqa](mathqa/README.md) | Question answering tasks involving mathematical reasoning and problem-solving. | English |
| [mbpp](mbpp/README.md) | A benchmark designed to measure the ability to synthesize short Python programs from natural language descriptions. | Python |
......@@ -105,9 +120,11 @@
| [minerva_math](minerva_math/README.md) | Mathematics-focused tasks requiring numerical reasoning and problem-solving skills. | English |
| [mlqa](mlqa/README.md) | MultiLingual Question Answering benchmark dataset for evaluating cross-lingual question answering performance. | English, Arabic, German, Spanish, Hindi, Vietnamese, Simplified Chinese |
| [mmlu](mmlu/README.md) | Massive Multitask Language Understanding benchmark for broad domain language evaluation. Several variants are supported. | English |
| [mmlu_redux](mmlu-redux/README.md) | Refined Massive Multitask Language Understanding benchmark for broad domain evaluation with improved data quality. | English |
| [mmlu_redux](mmlu-redux-spanish/README.md) | Refined Massive Multitask Language Understanding benchmark for broad domain evaluation with improved data quality. | Spanish |
| [mmlu_pro](mmlu_pro/README.md) | A refined set of MMLU, integrating more challenging, reasoning-focused questions and expanding the choice set from four to ten options. | English |
| [mmlu-pro-plus](mmlu-pro-plus/README.md) | A new test set for evaluating shortcut learning and higher-order reasoning of LLMs. | English |
| [mmlu_prox](mmlu_prox/README.md) | A multilingual benchmark that extends MMLU-Pro to multiple typologically diverse languages with human validation. | English, Japanese, Chinese, Korean, French, German, Spanish, Portuguese, Swahili, Thai, Arabic, Hindi, Bengali |
| [mmlu_prox](mmlu_prox/README.md) | A multilingual benchmark that extends MMLU-Pro to multiple typologically diverse languages with human validation. | English, Japanese, Chinese, Korean, French, German, Spanish, Portuguese, Zulu, Swahili, Wolof, Yoruba, Thai, Arabic, Hindi, Bengali, Serbian, Hungarian, Vietnamese, Czech, Marathi, Afrikaans, Nepali, Telugu, Urdu, Russian, Indonesian, Italian, Ukrainian|
| [mmlusr](mmlusr/README.md) | Variation of MMLU designed to be more rigorous. | English |
| model_written_evals | Evaluation tasks auto-generated for evaluating a collection of AI Safety concerns. | |
| [moral_stories](moral_stories/README.md) | A crowd-sourced dataset of structured narratives that describe normative and norm-divergent actions taken by individuals to accomplish certain intentions in concrete situations. | English |
......@@ -156,6 +173,7 @@
| [truthfulqa](truthfulqa/README.md) | A QA task aimed at evaluating the truthfulness and factual accuracy of model responses. | English |
| [truthfulqa-multi](truthfulqa-multi/README.md) | Is a multilingual version of TruthfulQA, a QA task aimed at evaluating the truthfulness and factual accuracy of model responses. | English, Spanish, Catalan, Basque, Galician |
| [turkishmmlu](turkishmmlu/README.md) | A multiple-choice QA test modeled after MMLU, written in Turkish based on Turkish high-school level exams. | Turkish |
| [turblimp_core](turblimp/README.md) | A benchmark evaluating language models' grammatical capabilities in Turkish based on comparing the probabilities of minimal pairs of grammatical and ungrammatical sentences. | Turkish |
| [unitxt](unitxt/README.md) | A number of tasks implemented using the unitxt library for flexible, shareable, and reusable data preparation and evaluation for generative AI. | English |
| [unscramble](unscramble/README.md) | Tasks involving the rearrangement of scrambled sentences to test syntactic understanding. | English |
| [webqs](webqs/README.md) | Web-based question answering tasks designed to evaluate internet search and retrieval. | English |
......@@ -171,9 +189,11 @@
| [xquad](xquad/README.md) | Cross-lingual Question Answering Dataset in multiple languages. | Arabic, German, Greek, English, Spanish, Hindi, Romanian, Russian, Thai, Turkish, Vietnamese, Chinese |
| [xstorycloze](xstorycloze/README.md) | Cross-lingual narrative understanding tasks to predict story endings in multiple languages. | Russian, Simplified Chinese, Spanish, Arabic, Hindi, Indonesian, Telugu, Swahili, Basque, Burmese |
| [xwinograd](xwinograd/README.md) | Cross-lingual Winograd schema tasks for coreference resolution in multiple languages. | English, French, Japanese, Portuguese, Russian, Chinese |
| [zhoblimp](zhoblimp/README.md) | A benchmark evaluating language models' grammatical capabilities in Chinese based on comparing the probabilities of minimal pairs of grammatical and ungrammatical sentences. | Chinese |
## Multimodal Tasks
| Task Family | Description | Modality |
|------------------------------|---------------------------------------------------------------------------------------------------------|-------------|
| ---------------------------- | ------------------------------------------------------------------------------------------------------- | ----------- |
| [chartqa](chartqa/README.md) | A benchmark for question answering about charts that requires both visual and logical reasoning. | Image, Text |
| [mmmu](mmmu/README.md) | Evaluate multimodal models on massive multi-discipline tasks demanding college-level subject knowledge. | Image, Text |
......@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self.indexes = None
class ACPGrammarParser(object):
class ACPGrammarParser:
def __init__(self, task) -> None:
self.task = task
with open(GRAMMAR_FILE) as f:
......@@ -556,8 +556,8 @@ class STRIPS:
return set([fix_name(str(x)) for x in ret])
def PDDL_replace_init_pddl_parser(self, s):
d = DomainParser()(open(self.domain_file, "r").read().lower())
p = ProblemParser()(open(self.problem_file, "r").read().lower())
d = DomainParser()(open(self.domain_file).read().lower())
p = ProblemParser()(open(self.problem_file).read().lower())
new_state = get_atoms_pddl(d, p, s | self.get_static())
......
......@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
self.indexes = None
class ACPGrammarParser(object):
class ACPGrammarParser:
def __init__(self, task) -> None:
self.task = task
with open(GRAMMAR_FILE) as f:
......@@ -556,8 +556,8 @@ class STRIPS:
return set([fix_name(str(x)) for x in ret])
def PDDL_replace_init_pddl_parser(self, s):
d = DomainParser()(open(self.domain_file, "r").read().lower())
p = ProblemParser()(open(self.problem_file, "r").read().lower())
d = DomainParser()(open(self.domain_file).read().lower())
p = ProblemParser()(open(self.problem_file).read().lower())
new_state = get_atoms_pddl(d, p, s | self.get_static())
......
......@@ -2,7 +2,6 @@ tag:
- adr_tasks
- adr_prompt_1
dataset_path: masakhane/diacritics-restoration
dataset_kwargs: {trust_remote_code: True}
doc_to_target: target
output_type: generate_until
fewshot_split: dev
......
......@@ -2,7 +2,6 @@ tag:
- adr_tasks
- adr_prompt_2
dataset_path: masakhane/diacritics-restoration
dataset_kwargs: {trust_remote_code: True}
doc_to_target: target
output_type: generate_until
fewshot_split: dev
......
......@@ -2,7 +2,6 @@ tag:
- adr_tasks
- adr_prompt_3
dataset_path: masakhane/diacritics-restoration
dataset_kwargs: {trust_remote_code: True}
doc_to_target: target
output_type: generate_until
fewshot_split: dev
......
......@@ -2,7 +2,6 @@ tag:
- adr_tasks
- adr_prompt_4
dataset_path: masakhane/diacritics-restoration
dataset_kwargs: {trust_remote_code: True}
doc_to_target: target
output_type: generate_until
fewshot_split: dev
......
......@@ -2,7 +2,6 @@ tag:
- adr_tasks
- adr_prompt_5
dataset_path: masakhane/diacritics-restoration
dataset_kwargs: {trust_remote_code: True}
doc_to_target: target
output_type: generate_until
fewshot_split: dev
......
......@@ -4,7 +4,6 @@ tag:
task: null
dataset_path: masakhane/afrisenti
dataset_name: null
dataset_kwargs: {trust_remote_code: True}
output_type: multiple_choice
validation_split: validation
test_split: test
......
......@@ -3,7 +3,6 @@ tag:
- afrisent_prompt_2
dataset_path: masakhane/afrisenti
dataset_name: null
dataset_kwargs: {trust_remote_code: True}
output_type: multiple_choice
validation_split: validation
test_split: test
......
......@@ -3,7 +3,6 @@ tag:
- afrisenti_prompt_3
dataset_path: masakhane/afrisenti
dataset_name: null
dataset_kwargs: {trust_remote_code: True}
output_type: multiple_choice
validation_split: validation
test_split: test
......
......@@ -3,7 +3,6 @@ tag:
- afrisenti_prompt_4
dataset_path: masakhane/afrisenti
dataset_name: null
dataset_kwargs: {trust_remote_code: True}
output_type: multiple_choice
validation_split: validation
test_split: test
......
......@@ -3,7 +3,6 @@ tag:
- afrisenti_prompt_5
dataset_path: masakhane/afrisenti
dataset_name: null
dataset_kwargs: {trust_remote_code: True}
output_type: multiple_choice
validation_split: validation
test_split: test
......
......@@ -73,3 +73,5 @@ HomePage: https://github.com/masakhane-io/masakhane-pos
abstract = "In this paper, we present AfricaPOS, the largest part-of-speech (POS) dataset for 20 typologically diverse African languages. We discuss the challenges in annotating POS for these languages using the universal dependencies (UD) guidelines. We conducted extensive POS baseline experiments using both conditional random field and several multilingual pre-trained language models. We applied various cross-lingual transfer models trained with data available in the UD. Evaluating on the AfricaPOS dataset, we show that choosing the best transfer language(s) in both single-source and multi-source setups greatly improves the POS tagging performance of the target languages, in particular when combined with parameter-fine-tuning methods. Crucially, transferring knowledge from a language that matches the language family and morphosyntactic properties seems to be more effective for POS tagging in unseen languages."
}
```
## Changelog
- 2025-07-21: Refactored. Scores should not be affected.
......@@ -14,19 +14,18 @@ validation_split: validation
test_split: test
fewshot_split: train
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results
should_decontaminate: true
doc_to_decontamination_query: "Sentence: {{token}}\nOutput:"
filter_list:
- filter:
- function: regex_pos
- function: "custom"
filter_fn: !function utils.extract_pos
- function: "take_first"
name: flexible-extract
metric_list:
- metric: acc
aggregation: !function utils.acc_score
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- ","
metadata:
version: 1.0
from itertools import chain
import re
from collections.abc import Iterable
from typing import Any
from sklearn.metrics import accuracy_score
from lm_eval.utils import weighted_f1_score
def doc_to_target(doc):
pos_tag_map = {
......@@ -29,27 +29,40 @@ def doc_to_target(doc):
return [pos_tag_map[tag] for tag in doc["upos"]]
def acc_score(items):
unzipped_list = list(zip(*items))
def extract_pos(resps: Iterable[list[str]], *args) -> Iterable[list[str]]:
def extract_tagged_tokens(text: str) -> list[tuple[str, str]]:
# Extract tagged tokens list from text input using regex
tokens = re.findall(
r"\('([^']*)', '([^']*)'\)",
"Here are some tuples: ('apple', 'red'), ('banana', 'yellow'), ('grape', 'purple')",
)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result: str):
pos_tags = []
if isinstance(result, str):
result_ = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result_)
return pos_tags if pos_tags else ["invalid"]
def filter_set(inst: list[str]) -> list[str]:
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
golds, preds = unzipped_list[0], unzipped_list[1]
filtered_resps = map(lambda x: filter_set(x), resps)
# Flatten preds' inner lists
flattened_preds = [list(chain.from_iterable(p)) for p in preds]
return filtered_resps
# Calculate the accuracy for each gold-pred pair
accuracy_scores = []
for gold, pred in zip(golds, flattened_preds):
# Ensure both lists are of the same length, otherwise truncate to match
min_length = min(len(gold), len(pred))
gold = gold[:min_length]
pred = pred[:min_length]
# Calculate accuracy for the current pair and add to the list
accuracy = accuracy_score(gold, pred)
accuracy_scores.append(accuracy)
def process_results(doc: dict[str, Any], results: list[list[str]]):
golds, preds = doc_to_target(doc), results[0]
# Ensure both lists are of the same length, otherwise truncate to match
min_length = min(len(golds), len(preds))
gold = golds[:min_length]
pred = preds[:min_length]
accuracy = accuracy_score(gold, pred)
mean_accuracy = (
sum(accuracy_scores) / len(accuracy_scores) if accuracy_scores else 0
)
return mean_accuracy
return {"acc": accuracy}