Commit abd17276 authored by Baber

Merge branch 'smolrefact' into tasklist

# Conflicts:
#	lm_eval/__main__.py
#	lm_eval/api/group.py
#	lm_eval/api/task.py
#	lm_eval/evaluator_utils.py
#	lm_eval/tasks/__init__.py
#	lm_eval/utils.py
#	pyproject.toml
parents 00afd536 70314843
+from __future__ import annotations
 import logging
 import os
 from functools import cached_property
 from operator import itemgetter
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any
 from lm_eval.api.registry import register_model
 from lm_eval.models.api_models import TemplateAPI
@@ -26,9 +28,9 @@ class LocalCompletionsAPI(TemplateAPI):
     def _create_payload(
         self,
-        messages: Union[List[List[int]], List[dict], List[str], str],
+        messages: list[list[int]] | list[dict] | list[str] | str,
         generate=False,
-        gen_kwargs: Optional[dict] = None,
+        gen_kwargs: dict | None = None,
         seed: int = 1234,
         eos=None,
         **kwargs,
@@ -50,24 +52,23 @@ class LocalCompletionsAPI(TemplateAPI):
                 "seed": seed,
                 **gen_kwargs,
             }
-        else:
-            return {
-                "model": self.model,
-                "prompt": messages,
-                "temperature": 0,
-                "max_tokens": 1,
-                "logprobs": 1,
-                "seed": seed,
-                "echo": True,
-            }
+        return {
+            "model": self.model,
+            "prompt": messages,
+            "temperature": 0,
+            "max_tokens": 1,
+            "logprobs": 1,
+            "seed": seed,
+            "echo": True,
+        }
     @staticmethod
     def parse_logprobs(
-        outputs: Union[Dict, List[Dict]],
-        tokens: List[List[int]] = None,
-        ctxlens: List[int] = None,
+        outputs: dict | list[dict],
+        tokens: list[list[int]] = None,
+        ctxlens: list[int] = None,
         **kwargs,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
@@ -88,7 +89,7 @@ class LocalCompletionsAPI(TemplateAPI):
         return res
     @staticmethod
-    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
+    def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
@@ -130,9 +131,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
     def _create_payload(
         self,
-        messages: List[Dict],
+        messages: list[dict],
         generate=False,
-        gen_kwargs: dict = None,
+        gen_kwargs: dict | None = None,
         seed=1234,
         eos=None,
         **kwargs,
@@ -160,7 +161,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
         }
     @staticmethod
-    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
+    def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
@@ -173,11 +174,11 @@ class LocalChatCompletion(LocalCompletionsAPI):
     def tok_encode(
         self,
-        string: Union[str, Any],
+        string: str | Any,
         left_truncate_len=None,
         add_special_tokens=None,
         **kwargs,
-    ) -> Union[List[str], List[int], Any]:
+    ) -> list[str] | list[int] | Any:
         return string
     def loglikelihood(self, requests, **kwargs):
@@ -219,7 +220,7 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
         )
         return super().loglikelihood(requests, **kwargs)
-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         return ""
@@ -261,7 +262,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
     def _create_payload(
         self,
-        messages: List[Dict],
+        messages: list[dict],
         generate=False,
         gen_kwargs: dict = None,
         seed=1234,
@@ -289,7 +290,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
             "seed": seed,
             **gen_kwargs,
         }
-        if "o1" in self.model:
+        if "o1" in self.model or "5" in self.model:
             output.pop("stop")
             output["temperature"] = 1
         elif "o3" in self.model:
......
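For reference (not part of this commit), a sketch of the two request shapes `_create_payload` builds: the `generate=True` branch returns a completion request assembled from `gen_kwargs`, while the fall-through return is the scoring request, where `echo=True`, `logprobs=1`, and `max_tokens=1` ask the server to return logprobs over the prompt itself rather than to generate. All values below are hypothetical:

```python
# Hypothetical payloads, mirroring _create_payload above.
generation_payload = {
    "model": "my-model",     # placeholder model name
    "prompt": "2 + 2 =",
    "max_tokens": 16,        # supplied via gen_kwargs
    "stop": ["\n"],          # supplied via gen_kwargs / eos
    "seed": 1234,
}
scoring_payload = {
    "model": "my-model",
    "prompt": "2 + 2 = 4",   # context + continuation to score
    "temperature": 0,
    "max_tokens": 1,         # generate (almost) nothing...
    "logprobs": 1,
    "seed": 1234,
    "echo": True,            # ...but echo prompt tokens with their logprobs
}
```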
@@ -28,9 +28,8 @@ class OptimumLM(HFLM):
         **kwargs,
     ) -> None:
         if "backend" in kwargs:
-            # optimum currently only supports causal models
-            assert kwargs["backend"] == "causal", (
-                "Currently, only OVModelForCausalLM is supported."
+            assert kwargs["backend"] in ["causal", "seq2seq"], (
+                "Currently, only OVModelForCausalLM or OVModelForSeq2SeqLM are supported."
             )
         self.openvino_device = device
@@ -54,7 +53,7 @@ class OptimumLM(HFLM):
                 "package `optimum` is not installed. Please install it via `pip install optimum[openvino]`"
             )
         else:
-            from optimum.intel.openvino import OVModelForCausalLM
+            from optimum.intel.openvino import OVModelForCausalLM, OVModelForSeq2SeqLM
         model_kwargs = kwargs if kwargs else {}
         if "ov_config" in model_kwargs:
@@ -76,17 +75,14 @@ class OptimumLM(HFLM):
             model_kwargs["ov_config"]["MODEL_DISTRIBUTION_POLICY"] = (
                 "PIPELINE_PARALLEL"
             )
-        model_file = Path(pretrained) / "openvino_model.xml"
-        if model_file.exists():
-            export = False
-        else:
-            export = True
-        self._model = OVModelForCausalLM.from_pretrained(
+        model_cls = (
+            OVModelForCausalLM if self.backend == "causal" else OVModelForSeq2SeqLM
+        )
+        self._model = model_cls.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
-           export=export,
            device=self.openvino_device.upper(),
            **model_kwargs,
        )
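A minimal sketch of the backend dispatch introduced above (assuming `optimum-intel` is installed; recent `optimum-intel` releases infer on their own whether a checkpoint still needs exporting to OpenVINO IR, which is presumably why the manual `openvino_model.xml` / `export=` detection could be dropped):

```python
from optimum.intel.openvino import OVModelForCausalLM, OVModelForSeq2SeqLM

def load_ov_model(pretrained: str, backend: str = "causal", device: str = "CPU"):
    # choose the OpenVINO model class by backend, as in the diff above
    model_cls = OVModelForCausalLM if backend == "causal" else OVModelForSeq2SeqLM
    # from_pretrained loads an existing OpenVINO IR, or exports one if absent
    return model_cls.from_pretrained(pretrained, device=device)
```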
@@ -216,7 +216,7 @@ class SGLangLM(TemplateLM):
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
-        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
+        re_ords = Collator(requests, _collate_gen, group_by=None)
         chunks = re_ords.get_batched(
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
@@ -232,36 +232,41 @@ class SGLangLM(TemplateLM):
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)
-            # we assume all gen kwargs in the batch are the same
-            # this is safe to assume because the `grouper` object ensures it.
-            gen_kwargs = all_gen_kwargs[0]
-            # unpack our keyword arguments.
-            if isinstance(gen_kwargs, dict):
-                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                # add EOS token to stop sequences
-                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
-            else:
-                raise ValueError(
-                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
-                )
-            if "max_gen_toks" in kwargs.keys():
-                max_gen_toks = kwargs.pop("max_gen_toks")
-            else:
-                max_gen_toks = self.max_gen_toks
-            # set the max length in tokens of inputs ("context_enc")
-            # max len for inputs = max length, minus room to generate the max new tokens
-            max_ctx_len = self.max_length - max_gen_toks
-            context_encoding = [x[-max_ctx_len:] for x in context_encoding]
+            context_encoding_truncated = []
+            sampling_params = []
+            for x, gen_kwargs in zip(context_encoding, all_gen_kwargs):
+                # unpack our keyword arguments.
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    # add EOS token to stop sequences
+                    until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
+                else:
+                    raise ValueError(
+                        f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
+                    )
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks
+                # set the max length in tokens of inputs ("context_enc")
+                # max len for inputs = max length, minus room to generate the max new tokens
+                max_ctx_len = self.max_length - max_gen_toks
+                if len(x) > max_ctx_len:
+                    context_encoding_truncated.append(x[-max_ctx_len:])
+                else:
+                    context_encoding_truncated.append(x)
+                # create sampling params
+                kwargs = self.modify_gen_kwargs(kwargs)
+                sampling_params.append(
+                    kwargs | {"max_tokens": max_gen_toks, "stop": until}
+                )
             # perform batched generation
             # cont is a list of dicts. See https://github.com/sgl-project/sglang/blob/0a6f18f068e4095fc228e798454e8496c9749214/python/sglang/srt/entrypoints/engine.py#L111 .
             cont = self._model_generate(
-                requests=context_encoding,
+                requests=context_encoding_truncated,
                 generate=True,
-                max_tokens=max_gen_toks,
-                stop=until,
-                **kwargs,
+                sampling_params=sampling_params,
             )
             # cache generations
@@ -284,28 +289,22 @@ class SGLangLM(TemplateLM):
         self,
         requests: List[List[int]] = None,
         generate: bool = False,
-        max_tokens: int = None,
-        stop: Optional[List[str]] = None,
+        sampling_params: Union[List[Dict], Dict, None] = None,
         return_logprob: bool = False,
         top_logprobs_num: int = 1,
         logprob_start_len: int = -1,
-        **kwargs,
     ):
         # check sglang sampling parameters: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/sampling/sampling_params.py#L21 and https://docs.sglang.ai/references/sampling_params.html.
-        if generate:
-            kwargs = self.modify_gen_kwargs(kwargs)
-            sampling_params = {
-                "max_new_tokens": max_tokens,
-                "stop": stop,
-            }
-            sampling_params.update(kwargs)
-        else:
-            sampling_params = {
-                "temperature": 0,
-                "max_new_tokens": 1,
-            }
-            sampling_params.update(kwargs)
+        if not generate:
+            sampling_params = sampling_params if sampling_params else {}
+            sampling_params.update(
+                {
+                    "temperature": 0,
+                    "max_new_tokens": 1,
+                }
+            )
+        if not isinstance(sampling_params, List):
+            sampling_params = [sampling_params] * len(requests)
         # Refer to: https://docs.sglang.ai/backend/offline_engine_api.html
         outputs = self.model.generate(
             input_ids=requests,
......
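Net effect of the SGLang changes above: requests are no longer grouped by identical `gen_kwargs`; every request in a batch now carries its own sampling-params dict, and `_model_generate` broadcasts a single dict to a list when needed. A rough usage sketch (model name is a placeholder; SGLang's offline `Engine.generate` accepts a list of per-request sampling params, which is what the new code relies on):

```python
import sglang as sgl

engine = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model

input_ids = [[1, 2, 3], [4, 5, 6, 7]]  # pre-tokenized prompts
sampling_params = [
    {"temperature": 0.0, "max_new_tokens": 64, "stop": ["</s>"]},
    {"temperature": 0.8, "max_new_tokens": 128, "stop": ["</s>", "\n\n"]},
]
# one sampling-params dict per request; mixed settings can share a batch
outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
```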
+from __future__ import annotations
 import copy
 import gc
-import inspect
 import logging
 import os
 from importlib.metadata import version
@@ -8,7 +9,7 @@ from importlib.util import find_spec
 from multiprocessing import Process, Queue
 from queue import Empty
 from time import sleep
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal
 import jinja2
 from more_itertools import distribute
@@ -33,7 +34,7 @@ from lm_eval.utils import (
 try:
     import ray
-    from vllm import LLM, SamplingParams
+    from vllm import LLM, SamplingParams, TokensPrompt
     from vllm.lora.request import LoRARequest
     from vllm.transformers_utils.tokenizer import get_tokenizer
     from vllm.utils import get_open_port
@@ -41,7 +42,7 @@ try:
     if parse_version(version("vllm")) >= parse_version("0.8.3"):
         from vllm.entrypoints.chat_utils import resolve_hf_chat_template
 except ModuleNotFoundError:
-    pass
+    print("njklsfnljnlsjnjlksnljnfvljnflsdnlksfnlkvnlksfvnlsfd")
 if TYPE_CHECKING:
     pass
@@ -51,7 +52,7 @@ eval_logger = logging.getLogger(__name__)
 def _vllm_mp_worker(
     model_args: dict,
-    sampling_params: "SamplingParams",
+    sampling_params: list["SamplingParams"],
     requests: list[list[int]],
     lora_request: "LoRARequest",
     result_queue: "Queue",
@@ -79,7 +80,7 @@ def _vllm_mp_worker(
     try:
         llm = LLM(**model_args)
         res = llm.generate(
-            prompt_token_ids=requests,
+            [TokensPrompt(prompt_token_ids=request) for request in requests],
             sampling_params=sampling_params,
             lora_request=lora_request,
         )
@@ -114,30 +115,30 @@ class VLLM(TemplateLM):
         self,
         pretrained: str,
         dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
-        revision: Optional[str] = None,
-        trust_remote_code: Optional[bool] = False,
-        tokenizer: Optional[str] = None,
+        revision: str | None = None,
+        trust_remote_code: bool | None = False,
+        tokenizer: str | None = None,
         tokenizer_mode: Literal["auto", "slow"] = "auto",
-        tokenizer_revision: Optional[str] = None,
-        add_bos_token: Optional[bool] = False,
-        prefix_token_id: Optional[int] = None,
+        tokenizer_revision: str | None = None,
+        add_bos_token: bool | None = False,
+        prefix_token_id: int | None = None,
         tensor_parallel_size: int = 1,
-        quantization: Optional[str] = None,
+        quantization: str | None = None,
         max_gen_toks: int = 256,
         swap_space: int = 4,
-        batch_size: Union[str, int] = 1,
-        max_batch_size=None,
-        max_length: int = None,
-        max_model_len: int = None,
+        batch_size: str | int = 1,
+        max_batch_size: int | None = None,
+        max_length: int | None = None,
+        max_model_len: int | None = None,
         seed: int = 1234,
         gpu_memory_utilization: float = 0.9,
         data_parallel_size: int = 1,
-        lora_local_path: str = None,
+        lora_local_path: str | None = None,
         # VLLM: enable thinking tags in the prompt.
         enable_thinking: bool = True,
         chat_template_args: Optional[dict] = None,
         # End marker for thinking tags - splits to get response after this token (if provided).
-        think_end_token: Optional[str] = None,
+        think_end_token: str | None = None,
         max_lora_rank: int = 16,
         **kwargs,
     ):
@@ -173,7 +174,7 @@ class VLLM(TemplateLM):
             "swap_space": int(swap_space),
             "quantization": quantization,
             "seed": int(seed),
-            "enable_lora": True if lora_local_path else False,
+            "enable_lora": bool(lora_local_path),
             "max_lora_rank": int(max_lora_rank),
         }
         self.model_args.update(kwargs)
@@ -196,6 +197,12 @@ class VLLM(TemplateLM):
             self.batch_size = "auto"
             eval_logger.info("Manual batching is not compatible with data parallelism.")
+        if "gemma" in pretrained.lower():
+            add_bos_token = True
+            eval_logger.info(
+                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
+            )
         from transformers import AutoConfig
         self._config = AutoConfig.from_pretrained(
@@ -214,11 +221,6 @@ class VLLM(TemplateLM):
             "enable_thinking", enable_thinking
         )
         self.add_bos_token = add_bos_token
-        if "gemma" in pretrained.lower():
-            self.add_bos_token = True
-            eval_logger.info(
-                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
-            )
         if parse_version(version("vllm")) >= parse_version("0.8.3"):
             kwargs_resolve_hf_chat_template = {
@@ -239,13 +241,6 @@ class VLLM(TemplateLM):
                 model_config = engine_args.create_model_config()
                 kwargs_resolve_hf_chat_template["model_config"] = model_config
-            # https://github.com/vllm-project/vllm/pull/18259
-            if (
-                "trsut_remote_code"
-                in inspect.signature(resolve_hf_chat_template).parameters
-            ):
-                kwargs_resolve_hf_chat_template["trsut_remote_code"] = trust_remote_code
-            else:
-                kwargs_resolve_hf_chat_template["trust_remote_code"] = trust_remote_code
+            kwargs_resolve_hf_chat_template["trust_remote_code"] = trust_remote_code
@@ -307,7 +302,7 @@ class VLLM(TemplateLM):
         return self._max_gen_toks
     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
@@ -344,14 +339,14 @@ class VLLM(TemplateLM):
     def tok_encode(
         self,
-        string: Union[str, List[str]],
+        string: str | list[str],
         left_truncate_len: int = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
-    ) -> Union[List[int], List[List[int]]]:
+    ) -> list[int] | list[list[int]]:
         if not add_special_tokens:
             add_special_tokens = False or self.add_bos_token
-        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+        encoding: list[list[int]] | list[int] = self.tokenizer(
             string,
             add_special_tokens=add_special_tokens,
             truncation=truncation,
@@ -369,19 +364,16 @@ class VLLM(TemplateLM):
     def _model_generate(
         self,
-        requests: List[List[int]] = None,
+        requests: list[list[int]] = None,
         generate: bool = False,
-        max_tokens: int = None,
-        stop: Optional[List[str]] = None,
-        **kwargs,
+        sampling_params: Union[List["SamplingParams"], "SamplingParams", None] = None,
     ):
-        if generate:
-            kwargs = self.modify_gen_kwargs(kwargs)
-            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
-        else:
+        if not generate or sampling_params is None:
             sampling_params = SamplingParams(
                 temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
             )
+        if not isinstance(sampling_params, List):
+            sampling_params = [sampling_params] * len(requests)
         if self.data_parallel_size > 1 and not self.V1:
             # vLLM hangs if resources are set in ray.remote
             # also seems to only work with decorator and not with ray.remote() fn
@@ -389,13 +381,13 @@ class VLLM(TemplateLM):
             @ray.remote
             def run_inference_one_model(
                 model_args: dict,
-                sampling_params: SamplingParams,
-                requests: List[List[int]],
-                lora_request: LoRARequest,
+                sampling_params: list["SamplingParams"],
+                requests: list[list[int]],
+                lora_request: "LoRARequest",
             ):
                 llm = LLM(**model_args)
                 return llm.generate(
-                    prompt_token_ids=requests,
+                    [TokensPrompt(prompt_token_ids=request) for request in requests],
                    sampling_params=sampling_params,
                    lora_request=lora_request,
                )
@@ -403,9 +395,12 @@ class VLLM(TemplateLM):
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
             requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
+            sampling_params = [
+                list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
+            ]
             inputs = (
-                (self.model_args, sampling_params, req, self.lora_request)
-                for req in requests
+                (self.model_args, sp, req, self.lora_request)
+                for req, sp in zip(requests, sampling_params)
             )
             object_refs = [run_inference_one_model.remote(*x) for x in inputs]
             results = ray.get(object_refs)
@@ -420,16 +415,18 @@ class VLLM(TemplateLM):
             dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()
             requests = (list(x) for x in distribute(self.data_parallel_size, requests))
+            sampling_params = (
+                list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
+            )
             procs, resq = [], Queue()
             # We use Process as it is non-daemonic
             try:
-                for rank, req in enumerate(requests):
+                for rank, (sp, req) in enumerate(zip(requests, sampling_params)):
                     proc = Process(
                         target=_vllm_mp_worker,
                         args=(
                             self.model_args.copy(),
-                            sampling_params,
+                            sp,
                             req,
                             self.lora_request,
                             resq,
@@ -459,7 +456,7 @@ class VLLM(TemplateLM):
                         if dead_procs:
                             raise RuntimeError(
                                 f"Worker processes {dead_procs} died unexpectedly"
-                            )
+                            ) from None
                         continue
                 results = [rank_res[i] for i in range(len(procs))]
@@ -484,16 +481,16 @@ class VLLM(TemplateLM):
         else:
             outputs = self.model.generate(
-                prompt_token_ids=requests,
+                [TokensPrompt(prompt_token_ids=request) for request in requests],
                 sampling_params=sampling_params,
-                use_tqdm=True if self.batch_size == "auto" else False,
+                use_tqdm=self.batch_size == "auto",
                 lora_request=self.lora_request,
             )
         return outputs
     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             adaptive_batch_size = len(requests)
@@ -508,7 +505,7 @@ class VLLM(TemplateLM):
                 disable=(disable_tqdm or (self.rank != 0)),
             )
         ):
-            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
+            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                 map(
                     make_disjoint_window,
                     get_rolling_token_windows(
@@ -561,13 +558,13 @@ class VLLM(TemplateLM):
         return loglikelihoods
     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []
         # batch tokenize contexts
         context, all_gen_kwargs = zip(*(req.args for req in requests))
-        context_encoding: List[List[int]] = self.tok_encode(
+        context_encoding: list[list[int]] = self.tok_encode(
             context, add_special_tokens=self.add_bos_token
         )
         requests = [
@@ -583,10 +580,11 @@ class VLLM(TemplateLM):
             # - any OOMs will happen right away rather than near the end
             return -len(_requests[0][1]), _requests[0][0]
-        # we group requests by their generation_kwargs,
-        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
-        # in the same batch.
-        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
+        re_ords = Collator(
+            requests,
+            _collate_gen,
+            group_by=None,
+        )
         chunks = re_ords.get_batched(
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
@@ -601,45 +599,48 @@ class VLLM(TemplateLM):
         for chunk in chunks:
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)
-            # we assume all gen kwargs in the batch are the same
-            # this is safe to assume because the `grouper` object ensures it.
-            gen_kwargs = all_gen_kwargs[0]
-            # unpack our keyword arguments.
-            if isinstance(gen_kwargs, dict):
-                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                # add EOS token to stop sequences
-                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
-            else:
-                raise ValueError(
-                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
-                )
-            if "max_gen_toks" in kwargs.keys():
-                max_gen_toks = kwargs.pop("max_gen_toks")
-            else:
-                max_gen_toks = self.max_gen_toks
-            # set the max length in tokens of inputs ("context_enc")
-            # max len for inputs = max length, minus room to generate the max new tokens
-            max_ctx_len = self.max_length - max_gen_toks
-            all_lengths = [len(x) for x in context_encoding]
-            for length in all_lengths:
-                if length > max_ctx_len:
-                    eval_logger.warning(
-                        f"Context length {length} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
-                    )
-            context_encoding = [x[-max_ctx_len:] for x in context_encoding]
+            context_encoding_truncated = []
+            sampling_params = []
+            for x, gen_kwargs in zip(context_encoding, all_gen_kwargs):
+                # unpack our keyword arguments.
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    # add EOS token to stop sequences
+                    until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
+                else:
+                    raise ValueError(
+                        f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
+                    )
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks
+                # set the max length in tokens of inputs ("context_enc")
+                # max len for inputs = max length, minus room to generate the max new tokens
+                max_ctx_len = self.max_length - max_gen_toks
+                if len(x) > max_ctx_len:
+                    eval_logger.warning(
+                        f"Context length {len(x)} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
+                    )
+                    context_encoding_truncated.append(x[-max_ctx_len:])
+                else:
+                    context_encoding_truncated.append(x)
+                # create sampling params
+                kwargs = self.modify_gen_kwargs(kwargs)
+                sampling_params.append(
+                    SamplingParams(max_tokens=max_gen_toks, stop=until, **kwargs)
+                )
             # perform batched generation
             cont = self._model_generate(
-                requests=context_encoding,
+                requests=context_encoding_truncated,
                 generate=True,
-                max_tokens=max_gen_toks,
-                stop=until,
-                **kwargs,
+                sampling_params=sampling_params,
            )
            # cache generations
-            for output, context in zip(cont, context):
+            for output, context_ in zip(cont, context):
                 generated_text: str = output.outputs[0].text
                 # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                 generated_text = postprocess_generated_text(
@@ -647,7 +648,7 @@ class VLLM(TemplateLM):
                 )
                 res.append(generated_text)
                 self.cache_hook.add_partial(
-                    "generate_until", (context, gen_kwargs), generated_text
+                    "generate_until", (context_, gen_kwargs), generated_text
                 )
                 pbar.update(1)
@@ -657,9 +658,9 @@ class VLLM(TemplateLM):
     def _loglikelihood_tokens(
         self,
-        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        requests: list[tuple[tuple[str, str], list[int], list[int]]],
         disable_tqdm: bool = False,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []
         def _collate(x):
@@ -680,7 +681,7 @@
         for chunk in chunks:
             inputs = []
             ctxlens = []
-            for cache_key, context_enc, continuation_enc in chunk:
+            for _cache_key, context_enc, continuation_enc in chunk:
                 if (
                     full_length := len(context_enc + continuation_enc)
                 ) > self.max_length:
@@ -718,7 +719,7 @@
         return re_ord.get_original(res)
     @staticmethod
-    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
+    def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
         """Process logprobs and tokens.
         :param tokens: list
......
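The vLLM side follows the same per-request pattern, plus two API updates visible in the diff: pre-tokenized prompts are wrapped in `TokensPrompt` objects instead of being passed through the deprecated `prompt_token_ids=` keyword, and `llm.generate` receives a list of `SamplingParams`, one per prompt. A minimal sketch (placeholder model and token ids):

```python
from vllm import LLM, SamplingParams, TokensPrompt

llm = LLM(model="facebook/opt-125m")  # placeholder model

token_ids = [[1, 2, 3, 4], [5, 6, 7]]  # pre-tokenized prompts
sampling_params = [
    SamplingParams(max_tokens=32, stop=["\n"]),
    SamplingParams(max_tokens=8, temperature=0.8),
]
# sampling_params is matched positionally with the prompts,
# so each request can use different generation settings
outputs = llm.generate(
    [TokensPrompt(prompt_token_ids=ids) for ids in token_ids],
    sampling_params=sampling_params,
)
```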
# Tasks
A list of supported tasks and task groupings can be viewed with `lm-eval --tasks list`.
For more information, including a full list of task names and their precise meanings or sources, follow the links provided to the individual README.md files for each subfolder.
| Task Family | Description | Language(s) | | Task Family | Description | Language(s) |
|--------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------| |--------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------|
| [eq-bench_es](eq_bench/README.md) | Spanish version of EQ-Bench (EN). Task for evaluating emotional reasoning through dialogue-based prompts. [Hugging Face](https://huggingface.co/datasets/BSC-LT/EQ-bench_es) |Spanish **Human Translated** |
| [eq-bench_ca](eq_bench/README.md) | Catalan version of EQ-Bench (EN). Task for evaluating emotional reasoning through dialogue-based prompts. [Hugging Face](https://huggingface.co/datasets/BSC-LT/EQ-bench_ca)| Catalan **Human Translated** |
| [aclue](aclue/README.md) | Tasks focusing on ancient Chinese language understanding and cultural aspects. | Ancient Chinese | | [aclue](aclue/README.md) | Tasks focusing on ancient Chinese language understanding and cultural aspects. | Ancient Chinese |
| [acp_bench](acpbench/README.md) | Tasks evaluating the reasoning ability about Action, Change, and Planning | English | | [acp_bench](acpbench/README.md) | Tasks evaluating the reasoning ability about Action, Change, and Planning | English |
| [acp_bench_hard](acpbench/README.md) | Tasks evaluating the reasoning ability about Action, Change, and Planning | English | | [acp_bench_hard](acpbench/README.md) | Tasks evaluating the reasoning ability about Action, Change, and Planning | English |
| [aexams](aexams/README.md) | Tasks in Arabic related to various academic exams covering a range of subjects. | Arabic | | [aexams](aexams/README.md) | Tasks in Arabic related to various academic exams covering a range of subjects. | Arabic |
| [agieval](agieval/README.md) | Tasks involving historical data or questions related to history and historical texts. | English, Chinese | | [agieval](agieval/README.md) | Tasks involving historical data or questions related to history and historical texts. | English, Chinese |
| [aime](aime/README.md) | High school math competition questions | English |
| [anli](anli/README.md) | Adversarial natural language inference tasks designed to test model robustness. | English | | [anli](anli/README.md) | Adversarial natural language inference tasks designed to test model robustness. | English |
| [arabic_leaderboard_complete](arabic_leaderboard_complete/README.md) | A full version of the tasks in the Open Arabic LLM Leaderboard, focusing on the evaluation of models that reflect the characteristics of Arabic language understanding and comprehension, culture, and heritage. Note that some of these tasks are machine-translated. | Arabic (Some MT) | | [arabic_leaderboard_complete](arabic_leaderboard_complete/README.md) | A full version of the tasks in the Open Arabic LLM Leaderboard, focusing on the evaluation of models that reflect the characteristics of Arabic language understanding and comprehension, culture, and heritage. Note that some of these tasks are machine-translated. | Arabic (Some MT) |
| [arabic_leaderboard_light](arabic_leaderboard_light/README.md) | A light version of the tasks in the Open Arabic LLM Leaderboard (i.e., 10% samples of the test set in the original benchmarks), focusing on the evaluation of models that reflect the characteristics of Arabic language understanding and comprehension, culture, and heritage. Note that some of these tasks are machine-translated. | Arabic (Some MT) | | [arabic_leaderboard_light](arabic_leaderboard_light/README.md) | A light version of the tasks in the Open Arabic LLM Leaderboard (i.e., 10% samples of the test set in the original benchmarks), focusing on the evaluation of models that reflect the characteristics of Arabic language understanding and comprehension, culture, and heritage. Note that some of these tasks are machine-translated. | Arabic (Some MT) |
| [arabicmmlu](arabicmmlu/README.md) | Localized Arabic version of MMLU with multiple-choice questions from 40 subjects. | Arabic | | [arabicmmlu](arabicmmlu/README.md) | Localized Arabic version of MMLU with multiple-choice questions from 40 subjects. | Arabic |
| [ArabCulture](arab_culture/README.md) | Benchmark for evaluating modeles' commonsense cultural knowledge across different 13 different Arab Countries. | Arabic | | [ArabCulture](arab_culture/README.md) | Benchmark for evaluating models' commonsense cultural knowledge across different 13 different Arab Countries. | Arabic |
| [AraDICE](aradice/README.md) | A collection of multiple tasks carefully designed to evaluate dialectal and cultural capabilities in large language models (LLMs). | Arabic | | [AraDICE](aradice/README.md) | A collection of multiple tasks carefully designed to evaluate dialectal and cultural capabilities in large language models (LLMs). | Arabic |
| [arc](arc/README.md) | Tasks involving complex reasoning over a diverse set of questions. | English | | [arc](arc/README.md) | Tasks involving complex reasoning over a diverse set of questions. | English |
| [arithmetic](arithmetic/README.md) | Tasks involving numerical computations and arithmetic reasoning. | English | | [arithmetic](arithmetic/README.md) | Tasks involving numerical computations and arithmetic reasoning. | English |
| [asdiv](asdiv/README.md) | Tasks involving arithmetic and mathematical reasoning challenges. | English | | [asdiv](asdiv/README.md) | Tasks involving arithmetic and mathematical reasoning challenges. | English |
| [babi](babi/README.md) | Tasks designed as question and answering challenges based on simulated stories. | English | | [babi](babi/README.md) | Tasks designed as question and answering challenges based on simulated stories. | English |
| [babilong](babilong/README.md) | Tasks designed to test whether models can find and reason over facts in long contexts. | English |
| [basque_bench](basque_bench/README.md) | Collection of tasks in Basque encompassing various evaluation areas. | Basque | | [basque_bench](basque_bench/README.md) | Collection of tasks in Basque encompassing various evaluation areas. | Basque |
| [basqueglue](basqueglue/README.md) | Tasks designed to evaluate language understanding in Basque language. | Basque | | [basqueglue](basqueglue/README.md) | Tasks designed to evaluate language understanding in Basque language. | Basque |
| [bbh](bbh/README.md) | Tasks focused on deep semantic understanding through hypothesization and reasoning. | English, German | | [bbh](bbh/README.md) | Tasks focused on deep semantic understanding through hypothesization and reasoning. | English, German |
...@@ -29,30 +33,36 @@ ...@@ -29,30 +33,36 @@
| [belebele](belebele/README.md) | Language understanding tasks in a variety of languages and scripts. | Multiple (122 languages) | | [belebele](belebele/README.md) | Language understanding tasks in a variety of languages and scripts. | Multiple (122 languages) |
| benchmarks | General benchmarking tasks that test a wide range of language understanding capabilities. | | | benchmarks | General benchmarking tasks that test a wide range of language understanding capabilities. | |
| [bertaqa](bertaqa/README.md) | Local Basque cultural trivia QA tests in English and Basque languages. | English, Basque, Basque (MT) | | [bertaqa](bertaqa/README.md) | Local Basque cultural trivia QA tests in English and Basque languages. | English, Basque, Basque (MT) |
| [bhs](bhs/README.md) | Grammatical knowledge evaluation for low-resource langauges. | Basque, Hindi, Swahili |
| [bigbench](bigbench/README.md) | Broad tasks from the BIG-bench benchmark designed to push the boundaries of large models. | Multiple | | [bigbench](bigbench/README.md) | Broad tasks from the BIG-bench benchmark designed to push the boundaries of large models. | Multiple |
| [blimp](blimp/README.md) | Tasks testing grammatical phenomena to evaluate language model's linguistic capabilities. | English | | [blimp](blimp/README.md) | Tasks testing grammatical phenomena to evaluate language model's linguistic capabilities. | English |
| [blimp_nl](blimp_nl/README.md) | A benchmark evaluating language models' grammatical capabilities in Dutch based on comparing the probabilities of minimal pairs of grammatical and ungrammatical sentences. | Dutch |
| [c4](c4/README.md) | Tasks based on a colossal, cleaned version of Common Crawl's web crawl corpus to assess models' language modeling capabilities. | English | | [c4](c4/README.md) | Tasks based on a colossal, cleaned version of Common Crawl's web crawl corpus to assess models' language modeling capabilities. | English |
| [cabbq](cabbq/README.md) | Adaptation of the [BBQ](bbq/README.md) benchmark to the Catalan language and stereotypes prevalent in Spain. | Catalan |
| [careqa](careqa/README.md) | Multiple choice and open-ended medical question answering based on the Spanish Specialised Healthcare Training (MIR) exams. | English, Spanish | | [careqa](careqa/README.md) | Multiple choice and open-ended medical question answering based on the Spanish Specialised Healthcare Training (MIR) exams. | English, Spanish |
| [catalan_bench](catalan_bench/README.md) | Collection of tasks in Catalan encompassing various evaluation areas. | Catalan | | [catalan_bench](catalan_bench/README.md) | Collection of tasks in Catalan encompassing various evaluation areas. | Catalan |
| [ceval](ceval/README.md) | Tasks that evaluate language understanding and reasoning in an educational context. | Chinese | | [ceval](ceval/README.md) | Tasks that evaluate language understanding and reasoning in an educational context. | Chinese |
| [cmmlu](cmmlu/README.md) | Multi-subject multiple choice question tasks for comprehensive academic assessment. | Chinese | | [cmmlu](cmmlu/README.md) | Multi-subject multiple choice question tasks for comprehensive academic assessment. | Chinese |
| code_x_glue | Tasks that involve understanding and generating code across multiple programming languages. | Go, Java, JS, PHP, Python, Ruby | | code_x_glue | Tasks that involve understanding and generating code across multiple programming languages. | Go, Java, JS, PHP, Python, Ruby |
| [commonsense_qa](commonsense_qa/README.md) | CommonsenseQA, a multiple-choice QA dataset for measuring commonsense knowledge. | English | | [commonsense_qa](commonsense_qa/README.md) | CommonsenseQA, a multiple-choice QA dataset for measuring commonsense knowledge. | English |
| [copal_id](copal_id/README.md) United States | Indonesian causal commonsense reasoning dataset that captures local nuances. | Indonesian | | [copal_id](copal_id/README.md) United States | Indonesian causal commonsense reasoning dataset that captures local nuances. | Indonesian |
| [coqa](coqa/README.md) | Conversational question answering tasks to test dialog understanding. | English | | [coqa](coqa/README.md) | Conversational question answering tasks to test dialog understanding. | English |
| [crows_pairs](crows_pairs/README.md) | Tasks designed to test model biases in various sociodemographic groups. | English, French | | [crows_pairs](crows_pairs/README.md) | Tasks designed to test model biases in various sociodemographic groups. | English, French |
| [click](click/README.md) | A benchmark dataset of Cultural and Linguistic Intelligence in Korean (CLIcK), comprising 1,995 QA pairs sourced from official Korean exams and textbooks to test Korean cultural and linguistic knowledge. | Korean |
| csatqa | Tasks related to SAT and other standardized testing questions for academic assessment. | Korean | | csatqa | Tasks related to SAT and other standardized testing questions for academic assessment. | Korean |
| [darija_bench](darija_bench/README.md) | Traditional NLP tasks (Translation, Summariation, etc..) for Moroccan Darija | Moroccan Darija (some MT) | | [darija_bench](darija_bench/README.md) | Traditional NLP tasks (Translation, Summarization, etc..) for Moroccan Darija | Moroccan Darija (some MT) |
| [darijahellaswag](darijahellaswag/README.md) | Moroccan Darija version of HellaSwag. | Moroccan Darija (MT) | | [darijahellaswag](darijahellaswag/README.md) | Moroccan Darija version of HellaSwag. | Moroccan Darija (MT) |
| [darijammlu](darijammlu/README.md) | Multiple-choice QA in Moroccan Darija (an Arabic dialect). | Moroccan Darija (MT) | | [darijammlu](darijammlu/README.md) | Multiple-choice QA in Moroccan Darija (an Arabic dialect). | Moroccan Darija (MT) |
| [discrim_eval](discrim_eval/README.md) | Prompts for binary decisions covering 70 scenarios to evaluate demographic bias. | English |
| [drop](drop/README.md) | Tasks requiring numerical reasoning, reading comprehension, and question answering. | English | | [drop](drop/README.md) | Tasks requiring numerical reasoning, reading comprehension, and question answering. | English |
| [egyhellaswag](egyhellaswag/README.md) | Egyptian Arabic (Masri) version of HellaSwag. | Egyptian Arabic (MT) | | [egyhellaswag](egyhellaswag/README.md) | Egyptian Arabic (Masri) version of HellaSwag. | Egyptian Arabic (MT) |
| [egymmlu](egymmlu/README.md) | Multiple-choice QA in Egyptian Arabic. | Egyptian Arabic (MT) | | [egymmlu](egymmlu/README.md) | Multiple-choice QA in Egyptian Arabic. | Egyptian Arabic (MT) |
| [eq_bench](eq_bench/README.md) | Tasks focused on equality and ethics in question answering and decision-making. | English | | [eq_bench](eq_bench/README.md) | Tasks focused on equality and ethics in question answering and decision-making. | English |
| [esbbq](esbbq/README.md) | Adaptation of the [BBQ](bbq/README.md) benchmark to the Spanish language and stereotypes prevalent in Spain. | Spanish |
| [eus_exams](eus_exams/README.md) | Tasks based on various professional and academic exams in the Basque language. | Basque | | [eus_exams](eus_exams/README.md) | Tasks based on various professional and academic exams in the Basque language. | Basque |
| [eus_proficiency](eus_proficiency/README.md) | Tasks designed to test proficiency in the Basque language across various topics. | Basque | | [eus_proficiency](eus_proficiency/README.md) | Tasks designed to test proficiency in the Basque language across various topics. | Basque |
| [eus_reading](eus_reading/README.md) | Reading comprehension tasks specifically designed for the Basque language. | Basque | | [eus_reading](eus_reading/README.md) | Reading comprehension tasks specifically designed for the Basque language. | Basque |
| [eus_trivia](eus_trivia/README.md) | Trivia and knowledge testing tasks in the Basque language. | Basque | | [eus_trivia](eus_trivia/README.md) | Trivia atypicnd knowledge testing tasks in the Basque language. | Basque |
| [evalita_LLM](evalita_llm/README.md) | A native Italian benchmark with diverse tasks formats and multiple prompts. | Italian | | [evalita_LLM](evalita_llm/README.md) | A native Italian benchmark with diverse tasks formats and multiple prompts. | Italian |
| [fda](fda/README.md) | Tasks for extracting key-value pairs from FDA documents to test information extraction. | English | | [fda](fda/README.md) | Tasks for extracting key-value pairs from FDA documents to test information extraction. | English |
| [fld](fld/README.md) | Tasks involving free-form and directed dialogue understanding. | English | | [fld](fld/README.md) | Tasks involving free-form and directed dialogue understanding. | English |
...@@ -71,13 +81,15 @@ ...@@ -71,13 +81,15 @@
| [histoires_morales](histoires_morales/README.md) | A dataset of structured narratives that describe normative and norm-divergent actions taken by individuals to accomplish certain intentions in concrete situations. | French (Some MT) | | [histoires_morales](histoires_morales/README.md) | A dataset of structured narratives that describe normative and norm-divergent actions taken by individuals to accomplish certain intentions in concrete situations. | French (Some MT) |
| [hrm8k](hrm8k/README.md) | A challenging bilingual math reasoning benchmark for Korean and English. | Korean (Some MT), English (Some MT) | | [hrm8k](hrm8k/README.md) | A challenging bilingual math reasoning benchmark for Korean and English. | Korean (Some MT), English (Some MT) |
| [humaneval](humaneval/README.md) | Code generation task that measure functional correctness for synthesizing programs from docstrings. | Python | | [humaneval](humaneval/README.md) | Code generation task that measure functional correctness for synthesizing programs from docstrings. | Python |
| [humaneval_infilling](humaneval_infilling/README.md) | Code generation task that measure fill-in-the-middle capability for synthesizing programs from docstrings. | Python |
| [icelandic_winogrande](icelandic_winogrande/README.md) | Manually translated and localized version of the [WinoGrande](winogrande/README.md) commonsense reasoning benchmark for Icelandic. | Icelandic |
| [ifeval](ifeval/README.md) | Interactive fiction evaluation tasks for narrative understanding and reasoning. | English | | [ifeval](ifeval/README.md) | Interactive fiction evaluation tasks for narrative understanding and reasoning. | English |
| [inverse_scaling](inverse_scaling/README.md) | Multiple-choice tasks from the Inverse Scaling Prize, designed to find settings where larger language models perform worse. | English | | [inverse_scaling](inverse_scaling/README.md) | Multiple-choice tasks from the Inverse Scaling Prize, designed to find settings where larger language models perform worse. | English |
| [japanese_leaderboard](japanese_leaderboard/README.md) | Japanese language understanding tasks to benchmark model performance on various linguistic aspects. | Japanese | | [japanese_leaderboard](japanese_leaderboard/README.md) | Japanese language understanding tasks to benchmark model performance on various linguistic aspects. | Japanese |
| [jsonschema_bench](jsonschema_bench/README.md) | Evaluate the ability of LLMs to generate JSON objects that conform to a given JSON schema, including API, configuration files, and other structured data formats. | JSON | | [jsonschema_bench](jsonschema_bench/README.md) | Evaluate the ability of LLMs to generate JSON objects that conform to a given JSON schema, including API, configuration files, and other structured data formats. | JSON |
| [kbl](kbl/README.md) | Korean Benchmark for Legal Language Understanding. | Korean | | [kbl](kbl/README.md) | Korean Benchmark for Legal Language Understanding. | Korean |
| [kmmlu](kmmlu/README.md) | Knowledge-based multi-subject multiple choice questions for academic evaluation. | Korean | | [kmmlu](kmmlu/README.md) | Knowledge-based multi-subject multiple choice questions for academic evaluation. | Korean |
| [kobest](kobest/README.md) | A collection of tasks designed to evaluate understanding in Korean language. | Korean | | [kobest](kobest/README.md) | A collection of tasks designed to evaluate understanding in Korean language{Fecha: language. | Korean |
| [kormedmcqa](kormedmcqa/README.md) | Medical question answering tasks in Korean to test specialized domain knowledge. | Korean | | [kormedmcqa](kormedmcqa/README.md) | Medical question answering tasks in Korean to test specialized domain knowledge. | Korean |
| [lambada](lambada/README.md) | Tasks designed to predict the endings of text passages, testing language prediction skills. | English | | [lambada](lambada/README.md) | Tasks designed to predict the endings of text passages, testing language prediction skills. | English |
| [lambada_cloze](lambada_cloze/README.md) | Cloze-style LAMBADA dataset. | English | | [lambada_cloze](lambada_cloze/README.md) | Cloze-style LAMBADA dataset. | English |
...@@ -85,9 +97,12 @@ ...@@ -85,9 +97,12 @@
| [lambada_multilingual_stablelm](lambada_multilingual_stablelm/README.md) | Multilingual LAMBADA dataset. Users should prefer evaluating on this version of the multilingual dataset instead of on `lambada_multilingual`. | German, English, Spanish, French, Italian, Dutch, Portuguese | | [lambada_multilingual_stablelm](lambada_multilingual_stablelm/README.md) | Multilingual LAMBADA dataset. Users should prefer evaluating on this version of the multilingual dataset instead of on `lambada_multilingual`. | German, English, Spanish, French, Italian, Dutch, Portuguese |
| [leaderboard](leaderboard/README.md) | Task group used by Hugging Face's [Open LLM Leaderboard v2](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard). Those tasks are static and will not change through time | English | | [leaderboard](leaderboard/README.md) | Task group used by Hugging Face's [Open LLM Leaderboard v2](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard). Those tasks are static and will not change through time | English |
 | [lingoly](lingoly/README.md) | Challenging logical reasoning benchmark in low-resource languages with controls for memorization. | English, Multilingual |
-| [libra](libra/README.md) | Evaluates long-context understanding in Russian across four complexity levels | Russian (MT) |
+| [llama3](llama3/README.md) | Evals reproducing those provided by the Llama team in the Hugging Face repo (instruct). | English, Multilingual |
+| [libra](libra/README.md) | Evaluates long-context understanding in Russian across four complexity levels. | Russian (MT) |
+| [lm_syneval](lm_syneval/README.md) | Evaluates the syntactic capabilities of language models. | English |
 | [logiqa](logiqa/README.md) | Logical reasoning tasks requiring advanced inference and deduction. | English, Chinese |
 | [logiqa2](logiqa2/README.md) | Large-scale logical reasoning dataset adapted from the Chinese Civil Service Examination. | English, Chinese |
+| [longbench](longbench/README.md) | LongBench evaluates language models' ability to understand lengthy texts across multiple tasks and languages. | English, Chinese |
 | [mastermind](mastermind/README.md) | Reasoning benchmark based on the board game of Mastermind. | English |
 | [mathqa](mathqa/README.md) | Question answering tasks involving mathematical reasoning and problem-solving. | English |
 | [mbpp](mbpp/README.md) | A benchmark designed to measure the ability to synthesize short Python programs from natural language descriptions. | Python |
...@@ -105,9 +120,11 @@
 | [minerva_math](minerva_math/README.md) | Mathematics-focused tasks requiring numerical reasoning and problem-solving skills. | English |
 | [mlqa](mlqa/README.md) | MultiLingual Question Answering benchmark dataset for evaluating cross-lingual question answering performance. | English, Arabic, German, Spanish, Hindi, Vietnamese, Simplified Chinese |
 | [mmlu](mmlu/README.md) | Massive Multitask Language Understanding benchmark for broad domain language evaluation. Several variants are supported. | English |
+| [mmlu_redux](mmlu-redux/README.md) | Refined Massive Multitask Language Understanding benchmark for broad domain evaluation with improved data quality. | English |
+| [mmlu_redux_spanish](mmlu-redux-spanish/README.md) | Refined Massive Multitask Language Understanding benchmark for broad domain evaluation with improved data quality. | Spanish |
 | [mmlu_pro](mmlu_pro/README.md) | A refined set of MMLU, integrating more challenging, reasoning-focused questions and expanding the choice set from four to ten options. | English |
 | [mmlu-pro-plus](mmlu-pro-plus/README.md) | A new test set for evaluating shortcut learning and higher-order reasoning of LLMs. | English |
-| [mmlu_prox](mmlu_prox/README.md) | A multilingual benchmark that extends MMLU-Pro to multiple typologically diverse languages with human validation. | English, Japanese, Chinese, Korean, French, German, Spanish, Portuguese, Swahili, Thai, Arabic, Hindi, Bengali |
+| [mmlu_prox](mmlu_prox/README.md) | A multilingual benchmark that extends MMLU-Pro to multiple typologically diverse languages with human validation. | English, Japanese, Chinese, Korean, French, German, Spanish, Portuguese, Zulu, Swahili, Wolof, Yoruba, Thai, Arabic, Hindi, Bengali, Serbian, Hungarian, Vietnamese, Czech, Marathi, Afrikaans, Nepali, Telugu, Urdu, Russian, Indonesian, Italian, Ukrainian |
 | [mmlusr](mmlusr/README.md) | Variation of MMLU designed to be more rigorous. | English |
 | model_written_evals | Evaluation tasks auto-generated for evaluating a collection of AI Safety concerns. | |
 | [moral_stories](moral_stories/README.md) | A crowd-sourced dataset of structured narratives that describe normative and norm-divergent actions taken by individuals to accomplish certain intentions in concrete situations. | English |
...@@ -156,6 +173,7 @@
 | [truthfulqa](truthfulqa/README.md) | A QA task aimed at evaluating the truthfulness and factual accuracy of model responses. | English |
 | [truthfulqa-multi](truthfulqa-multi/README.md) | A multilingual version of TruthfulQA, a QA task aimed at evaluating the truthfulness and factual accuracy of model responses. | English, Spanish, Catalan, Basque, Galician |
 | [turkishmmlu](turkishmmlu/README.md) | A multiple-choice QA test modeled after MMLU, written in Turkish based on Turkish high-school level exams. | Turkish |
+| [turblimp_core](turblimp/README.md) | A benchmark evaluating language models' grammatical capabilities in Turkish based on comparing the probabilities of minimal pairs of grammatical and ungrammatical sentences. | Turkish |
 | [unitxt](unitxt/README.md) | A number of tasks implemented using the unitxt library for flexible, shareable, and reusable data preparation and evaluation for generative AI. | English |
 | [unscramble](unscramble/README.md) | Tasks involving the rearrangement of scrambled sentences to test syntactic understanding. | English |
 | [webqs](webqs/README.md) | Web-based question answering tasks designed to evaluate internet search and retrieval. | English |
...@@ -171,9 +189,11 @@
 | [xquad](xquad/README.md) | Cross-lingual Question Answering Dataset in multiple languages. | Arabic, German, Greek, English, Spanish, Hindi, Romanian, Russian, Thai, Turkish, Vietnamese, Chinese |
 | [xstorycloze](xstorycloze/README.md) | Cross-lingual narrative understanding tasks to predict story endings in multiple languages. | Russian, Simplified Chinese, Spanish, Arabic, Hindi, Indonesian, Telugu, Swahili, Basque, Burmese |
 | [xwinograd](xwinograd/README.md) | Cross-lingual Winograd schema tasks for coreference resolution in multiple languages. | English, French, Japanese, Portuguese, Russian, Chinese |
+| [zhoblimp](zhoblimp/README.md) | A benchmark evaluating language models' grammatical capabilities in Chinese based on comparing the probabilities of minimal pairs of grammatical and ungrammatical sentences. | Chinese |

 ## Multimodal Tasks

 | Task Family | Description | Modality |
-|------------------------------|---------------------------------------------------------------------------------------------------------|-------------|
+| ---------------------------- | ------------------------------------------------------------------------------------------------------- | ----------- |
 | [chartqa](chartqa/README.md) | A benchmark for question answering about charts that requires both visual and logical reasoning. | Image, Text |
 | [mmmu](mmmu/README.md) | Evaluate multimodal models on massive multi-discipline tasks demanding college-level subject knowledge. | Image, Text |
...@@ -81,7 +81,7 @@ class ACPBench_Visitor(Visitor):
         self.indexes = None


-class ACPGrammarParser(object):
+class ACPGrammarParser:
     def __init__(self, task) -> None:
         self.task = task
         with open(GRAMMAR_FILE) as f:
...@@ -556,8 +556,8 @@ class STRIPS:
         return set([fix_name(str(x)) for x in ret])

     def PDDL_replace_init_pddl_parser(self, s):
-        d = DomainParser()(open(self.domain_file, "r").read().lower())
-        p = ProblemParser()(open(self.problem_file, "r").read().lower())
+        d = DomainParser()(open(self.domain_file).read().lower())
+        p = ProblemParser()(open(self.problem_file).read().lower())
         new_state = get_atoms_pddl(d, p, s | self.get_static())
...
...@@ -2,7 +2,6 @@ tag:
 - adr_tasks
 - adr_prompt_1
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
...
...@@ -2,7 +2,6 @@ tag:
 - adr_tasks
 - adr_prompt_2
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
...
...@@ -2,7 +2,6 @@ tag:
 - adr_tasks
 - adr_prompt_3
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
...
...@@ -2,7 +2,6 @@ tag:
 - adr_tasks
 - adr_prompt_4
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
...
...@@ -2,7 +2,6 @@ tag:
 - adr_tasks
 - adr_prompt_5
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
...
...@@ -4,7 +4,6 @@ tag:
 task: null
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
...
...@@ -3,7 +3,6 @@ tag:
 - afrisent_prompt_2
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
...
...@@ -3,7 +3,6 @@ tag:
 - afrisenti_prompt_3
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
...
...@@ -3,7 +3,6 @@ tag:
 - afrisenti_prompt_4
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
...
...@@ -3,7 +3,6 @@ tag:
 - afrisenti_prompt_5
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
...
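For context on the change repeated across these YAML files: the harness forwards `dataset_kwargs` entries as keyword arguments to `datasets.load_dataset`, so dropping the line corresponds roughly to the sketch below. This is illustrative only — the `"amh"` config name is an assumed example, and the harness's actual call site is not shown here.

```python
from datasets import load_dataset

# Before: the YAML's dataset_kwargs were forwarded to load_dataset, i.e.
# ds = load_dataset("masakhane/afrisenti", "amh", trust_remote_code=True)

# After: the flag is dropped; recent `datasets` releases no longer execute
# dataset loading scripts, so hub-hosted data like this loads without it.
ds = load_dataset("masakhane/afrisenti", "amh")
print(ds["test"][0])
```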
...@@ -73,3 +73,5 @@ HomePage: https://github.com/masakhane-io/masakhane-pos
 abstract = "In this paper, we present AfricaPOS, the largest part-of-speech (POS) dataset for 20 typologically diverse African languages. We discuss the challenges in annotating POS for these languages using the universal dependencies (UD) guidelines. We conducted extensive POS baseline experiments using both conditional random field and several multilingual pre-trained language models. We applied various cross-lingual transfer models trained with data available in the UD. Evaluating on the AfricaPOS dataset, we show that choosing the best transfer language(s) in both single-source and multi-source setups greatly improves the POS tagging performance of the target languages, in particular when combined with parameter-fine-tuning methods. Crucially, transferring knowledge from a language that matches the language family and morphosyntactic properties seems to be more effective for POS tagging in unseen languages."
 }
 ```
+## Changelog
+- 2025-07-21: Refactored. Scores should not be affected.
...@@ -14,19 +14,18 @@ validation_split: validation
 test_split: test
 fewshot_split: train
 doc_to_target: !function utils.doc_to_target
+process_results: !function utils.process_results
 should_decontaminate: true
 doc_to_decontamination_query: "Sentence: {{token}}\nOutput:"
 filter_list:
   - filter:
-      - function: regex_pos
+      - function: "custom"
+        filter_fn: !function utils.extract_pos
+      - function: "take_first"
     name: flexible-extract
 metric_list:
   - metric: acc
-    aggregation: !function utils.acc_score
+    aggregation: mean
     higher_is_better: true
-    ignore_case: true
-    ignore_punctuation: true
-    regexes_to_ignore:
-      - ","
 metadata:
   version: 1.0
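To make the new filter pipeline concrete, here is a minimal sketch of what this `filter_list` entry does at evaluation time. It is illustrative only: the real dispatch lives inside the harness, and the sample model output is invented.

```python
from utils import extract_pos  # the task's utils.py, shown below

# One instance with one generated response, as the harness would collect it.
resps = [["Tags: ('Mo', 'PRON'), ('ti', 'AUX'), ('lo', 'VERB')"]]

# Step 1: the "custom" filter maps utils.extract_pos over all responses,
# keeping just the POS field of each ('token', 'POS') tuple.
filtered = list(extract_pos(resps))     # [[['PRON', 'AUX', 'VERB']]]

# Step 2: "take_first" keeps the first filtered response per instance.
preds = [inst[0] for inst in filtered]  # [['PRON', 'AUX', 'VERB']]
```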
-from itertools import chain
+import re
+from collections.abc import Iterable
+from typing import Any
 from sklearn.metrics import accuracy_score
-from lm_eval.utils import weighted_f1_score


 def doc_to_target(doc):
     pos_tag_map = {
...@@ -29,27 +29,40 @@ def doc_to_target(doc):
     return [pos_tag_map[tag] for tag in doc["upos"]]
-def acc_score(items):
-    unzipped_list = list(zip(*items))
-    golds, preds = unzipped_list[0], unzipped_list[1]
-    # Flatten preds' inner lists
-    flattened_preds = [list(chain.from_iterable(p)) for p in preds]
-    # Calculate the accuracy for each gold-pred pair
-    accuracy_scores = []
-    for gold, pred in zip(golds, flattened_preds):
-        # Ensure both lists are of the same length, otherwise truncate to match
-        min_length = min(len(gold), len(pred))
-        gold = gold[:min_length]
-        pred = pred[:min_length]
-        # Calculate accuracy for the current pair and add to the list
-        accuracy = accuracy_score(gold, pred)
-        accuracy_scores.append(accuracy)
-    mean_accuracy = (
-        sum(accuracy_scores) / len(accuracy_scores) if accuracy_scores else 0
-    )
-    return mean_accuracy
+def extract_pos(resps: Iterable[list[str]], *args) -> Iterable[list[str]]:
+    def extract_tagged_tokens(text: str) -> list[tuple[str, str]]:
+        # Extract the ('token', 'POS') tuples from the text input using a regex
+        tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
+        return [(token, pos) for token, pos in tokens]
+
+    def extract_pos_tags(result: str):
+        pos_tags = []
+        if isinstance(result, str):
+            result_ = extract_tagged_tokens(result)
+            pos_tags.extend(pos for _, pos in result_)
+        return pos_tags if pos_tags else ["invalid"]
+
+    def filter_set(inst: list[str]) -> list[str]:
+        filtered = []
+        for resp in inst:
+            match = extract_pos_tags(resp)
+            filtered.append(match)
+        return filtered
+
+    filtered_resps = map(lambda x: filter_set(x), resps)
+    return filtered_resps
+
+
+def process_results(doc: dict[str, Any], results: list[list[str]]):
+    golds, preds = doc_to_target(doc), results[0]
+    # Ensure both lists are of the same length, otherwise truncate to match
+    min_length = min(len(golds), len(preds))
+    gold = golds[:min_length]
+    pred = preds[:min_length]
+    accuracy = accuracy_score(gold, pred)
+    return {"acc": accuracy}
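And a small worked example of the new per-document scoring (values invented; real `golds` come from `doc_to_target` and real `preds` from the `extract_pos` filter above):

```python
from sklearn.metrics import accuracy_score

golds = ["NOUN", "ADP", "VERB", "PUNCT"]  # reference tags for one sentence
preds = ["NOUN", "ADP", "ADV"]            # the model tagged one token fewer

# process_results truncates both lists to the shorter length before scoring,
# so a short prediction is judged only on the tokens it actually produced.
n = min(len(golds), len(preds))
print(accuracy_score(golds[:n], preds[:n]))  # 2 of 3 correct -> 0.666...
```

Each document's `acc` is then averaged across the dataset by the `aggregation: mean` setting in the YAML above, replacing the old corpus-level `acc_score` aggregation.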