Commit 930b4253 authored by Baber

Merge branch 'smolrefact' into lazy_reg

# Conflicts:
#	lm_eval/__init__.py
#	lm_eval/api/metrics.py
#	lm_eval/api/registry.py
#	lm_eval/api/task.py
#	lm_eval/filters/__init__.py
#	pyproject.toml
parents d547b663 73202a2e
-from itertools import chain
+import re
+from collections.abc import Iterable
+from typing import Any
+
 from sklearn.metrics import accuracy_score
-from lm_eval.utils import weighted_f1_score


 def doc_to_target(doc):
     pos_tag_map = {
@@ -29,27 +29,40 @@ def doc_to_target(doc):
     return [pos_tag_map[tag] for tag in doc["upos"]]


-def acc_score(items):
-    unzipped_list = list(zip(*items))
-    golds, preds = unzipped_list[0], unzipped_list[1]
-    # Flatten preds' inner lists
-    flattened_preds = [list(chain.from_iterable(p)) for p in preds]
-    # Calculate the accuracy for each gold-pred pair
-    accuracy_scores = []
-    for gold, pred in zip(golds, flattened_preds):
-        # Ensure both lists are of the same length, otherwise truncate to match
-        min_length = min(len(gold), len(pred))
-        gold = gold[:min_length]
-        pred = pred[:min_length]
-        # Calculate accuracy for the current pair and add to the list
-        accuracy = accuracy_score(gold, pred)
-        accuracy_scores.append(accuracy)
-    mean_accuracy = (
-        sum(accuracy_scores) / len(accuracy_scores) if accuracy_scores else 0
-    )
-    return mean_accuracy
+def extract_pos(resps: Iterable[list[str]], *args) -> Iterable[list[str]]:
+    def extract_tagged_tokens(text: str) -> list[tuple[str, str]]:
+        # Extract the ('token', 'TAG') tuples from the model's text output using a regex
+        tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
+        return [(token, pos) for token, pos in tokens]
+
+    def extract_pos_tags(result: str) -> list[str]:
+        pos_tags = []
+        if isinstance(result, str):
+            result_ = extract_tagged_tokens(result)
+            pos_tags.extend(pos for _, pos in result_)
+        return pos_tags if pos_tags else ["invalid"]
+
+    def filter_set(inst: list[str]) -> list[list[str]]:
+        filtered = []
+        for resp in inst:
+            filtered.append(extract_pos_tags(resp))
+        return filtered
+
+    filtered_resps = map(filter_set, resps)
+    return filtered_resps
+
+
+def process_results(doc: dict[str, Any], results: list[list[str]]):
+    golds, preds = doc_to_target(doc), results[0]
+    # Ensure both lists are of the same length, otherwise truncate to match
+    min_length = min(len(golds), len(preds))
+    gold = golds[:min_length]
+    pred = preds[:min_length]
+    accuracy = accuracy_score(gold, pred)
+    return {"acc": accuracy}
@@ -16,17 +16,16 @@ fewshot_split: train
 doc_to_target: !function utils.doc_to_target
 should_decontaminate: true
 doc_to_decontamination_query: "Sentence: {{token}}\nOutput:"
+process_results: !function utils.process_results
 filter_list:
   - filter:
-      - function: regex_pos
+      - function: "custom"
+        filter_fn: !function utils.extract_pos
+      - function: "take_first"
     name: flexible-extract
 metric_list:
   - metric: acc
-    aggregation: !function utils.acc_score
+    aggregation: mean
     higher_is_better: true
-    ignore_case: true
-    ignore_punctuation: true
-    regexes_to_ignore:
-      - ","
 metadata:
   version: 1.0
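Taken together, the new wiring replaces the old regex_pos filter and acc_score aggregation: extract_pos turns each raw completion into a list of POS tags, process_results computes per-document accuracy, and the task aggregates with the built-in mean. A self-contained sketch of that flow on a made-up model response (the response string and gold tags below are invented for illustration):

import re
from sklearn.metrics import accuracy_score

def extract_tagged_tokens(text):
    # Same tuple-matching regex used by utils.extract_pos above
    return re.findall(r"\('([^']*)', '([^']*)'\)", text)

response = "[('The', 'DET'), ('dog', 'NOUN'), ('barks', 'VERB')]"   # invented model output
gold = ["DET", "NOUN", "VERB"]                                      # what doc_to_target would return
pred = [pos for _, pos in extract_tagged_tokens(response)] or ["invalid"]
n = min(len(gold), len(pred))                                       # truncate to the shorter list
print({"acc": float(accuracy_score(gold[:n], pred[:n]))})           # {'acc': 1.0}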
-from lm_eval.utils import weighted_f1_score


 def doc_to_text(doc):
     output = """Please provide the POS tags for each word in the input sentence. The input will be a list of words in
     the sentence. The output format should be a list of tuples, where each tuple consists of a word from the input text
...
@@ -24,3 +24,6 @@ journal = {Transactions of the Association of Computational Linguistics}}
 ### Tasks

 * `nq_open`
+
+### Changelog
+* 2025-07-21: Added `multi_target` to `exact_match`. Scores should not change.
 task: nq_open
-dataset_path: nq_open
+dataset_path: google-research-datasets/nq_open
 output_type: generate_until
 training_split: train
 validation_split: validation
 description: "Answer these questions:\n\n"
 doc_to_text: "Q: {{question}}?\nA:"
-doc_to_target: "{{answer}}" # TODO: should be multi-target
+doc_to_target: "{{answer}}"
 fewshot_delimiter: "\n"
 generation_kwargs:
   until:
@@ -28,5 +28,6 @@ metric_list:
     ignore_punctuation: true
     regexes_to_ignore:
      - "\\b(?:The |the |An |A |The |a |an )"
+    multi_target: true
 metadata:
   version: 4.0
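For readers unfamiliar with the new flag: with multi_target: true, exact_match treats the target as a list of acceptable answers, and a prediction scores 1 if it matches any of them. A simplified sketch of that idea (not the harness's actual metric code, which also applies ignore_case, ignore_punctuation, and regexes_to_ignore before comparing):

def exact_match_multi_target(prediction: str, references: list[str]) -> float:
    # Score 1.0 if the prediction matches any acceptable answer (after trivial normalization).
    normalize = lambda s: s.strip().lower()
    return float(any(normalize(prediction) == normalize(ref) for ref in references))

print(exact_match_multi_target("san francisco", ["San Francisco", "SF"]))  # 1.0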
@@ -2,9 +2,9 @@ import re
 from abc import abstractmethod
 from functools import reduce

+import datasets
 import numpy as np
 import transformers.data.metrics.squad_metrics as squad_metrics
-from datasets import Dataset
 from evaluate import load
 from transformers import AutoTokenizer
@@ -135,26 +135,10 @@ class _SCROLLSTask(ConfigurableTask):
         return False

     def training_docs(self):
-        processed_docs = list(map(self._process_doc, self.dataset["train"]))
-        # Flatten the list of lists since _process_doc returns a list of one element.
-        processed_docs = [item for sublist in processed_docs for item in sublist]
-        processed_dict = {
-            key: [d[key] for d in processed_docs] for key in processed_docs[0]
-        }
-        return Dataset.from_dict(processed_dict)
+        return self.dataset["train"].map(self._process_doc)

     def validation_docs(self):
-        processed_docs = list(map(self._process_doc, self.dataset["validation"]))
-        # Flatten the list of lists since _process_doc returns a list of one element.
-        processed_docs = [item for sublist in processed_docs for item in sublist]
-        processed_dict = {
-            key: [d[key] for d in processed_docs] for key in processed_docs[0]
-        }
-        return Dataset.from_dict(processed_dict)
+        return self.dataset["validation"].map(self._process_doc)

     def should_decontaminate(self):
         return True
@@ -163,8 +147,9 @@ class _SCROLLSTask(ConfigurableTask):
         return doc["input"]

     def download(self, *args, **kwargs):
-        super().download(*args, **kwargs)
-        del self.dataset["test"]
+        self.dataset: datasets.DatasetDict = datasets.load_dataset(
+            self.DATASET_PATH, self.DATASET_NAME, splits=["train", "validation"]
+        )
         for split in self.dataset:
             self.dataset[split] = _drop_duplicates_in_input(self.dataset[split])
         if self.PRUNE_TOKENIZERS is not None:
@@ -173,23 +158,26 @@ class _SCROLLSTask(ConfigurableTask):
     def _get_prune_text(self, sample):
         return self.doc_to_text(self._process_doc(sample)[0])

-    def prune(self):
+    def prune(self, **kwargs):
         """Create a pruned version of a SCROLLS task dataset containing only inputs
         that are less than `max_tokens` when tokenized by each tokenizer
         """
-        tokenizers = [
-            AutoTokenizer.from_pretrained(tokenizer)
-            for tokenizer in self.PRUNE_TOKENIZERS
-        ]
+        toks = [kwargs.get("tokenizer", kwargs.get("pretrained"))]
+        if self.PRUNE_TOKENIZERS is not None:
+            toks.extend(self.PRUNE_TOKENIZERS)
+        max_length = self.PRUNE_MAX_TOKENS or kwargs.get("max_length")
+        tokenizers = [AutoTokenizer.from_pretrained(tokenizer) for tokenizer in toks]
         cache = {}

         def _filter(sample):
             text = self._get_prune_text(sample)
-            cached = cache.get(text, None)
+            cached = cache.get(text)
             if cached is None:
                 for tokenizer in tokenizers:
-                    if len(tokenizer(text).input_ids) > self.PRUNE_MAX_TOKENS:
+                    if (
+                        max_length is not None
+                        and len(tokenizer(text).input_ids) > max_length
+                    ):
                         cache[text] = False
                         return False
                 cache[text] = True
@@ -206,7 +194,7 @@ class _SCROLLSTask(ConfigurableTask):
         return f"{doc['text']}\n\nQuestion: {doc['question']}\nAnswer:"

     def higher_is_better(self):
-        return {x: True for x in self._scrolls_metrics().keys()}
+        return {x: True for x in self._scrolls_metrics()}

     @abstractmethod
     def _scrolls_metrics(self):
@@ -263,9 +251,9 @@ class _SCROLLSMultipleChoiceTask(_SCROLLSTask):
             Instance(
                 request_type="loglikelihood",
                 doc=doc,
-                arguments=(ctx, " {}".format(choice))
+                arguments=(ctx, f" {choice}")
                 if not apply_chat_template
-                else (ctx, "{}".format(choice)),
+                else (ctx, f"{choice}"),
                 idx=i,
                 **kwargs,
             )
...
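The training_docs/validation_docs change above swaps the manual flatten-and-Dataset.from_dict round trip for datasets' built-in map. A toy sketch of that API (invented columns, nothing to do with SCROLLS itself):

from datasets import Dataset

ds = Dataset.from_dict({"input": ["a b c", "d e"]})
# map() applies the function per example and merges the returned dict into each row
processed = ds.map(lambda doc: {"n_words": len(doc["input"].split())})
print(processed[0])  # {'input': 'a b c', 'n_words': 3}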
@@ -49,3 +49,6 @@ If other tasks on this dataset are already supported:
 * [ ] Is the "Main" variant of this task clearly denoted?
 * [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
 * [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
+
+### Changelog
+* 2025-07-21: Added `multi_target` to `exact_match`. Scores should not change.
 task: triviaqa
-dataset_path: trivia_qa
+dataset_path: mandarjoshi/trivia_qa
 dataset_name: rc.nocontext
 output_type: generate_until
 training_split: train
@@ -27,5 +27,6 @@ metric_list:
     higher_is_better: true
     ignore_case: true
     ignore_punctuation: true
+    multi_target: true
 metadata:
   version: 3.0
+from __future__ import annotations
+
 import collections
 import fnmatch
 import functools
@@ -8,14 +10,16 @@ import json
 import logging
 import os
 import re
+from collections.abc import Generator
 from dataclasses import asdict, is_dataclass
+from functools import lru_cache, partial, wraps
 from itertools import islice
 from pathlib import Path
-from typing import Any, Callable, Generator, List, Optional, Tuple
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

 import numpy as np
 import yaml
-from jinja2 import BaseLoader, Environment, StrictUndefined
+from jinja2 import BaseLoader, Environment, StrictUndefined, Template

 SPACING = " " * 47
@@ -24,8 +28,6 @@ HIGHER_IS_BETTER_SYMBOLS = {
     True: "↑",
     False: "↓",
 }
-
-
 def wrap_text(string: str, width: int = 140, **kwargs) -> Optional[str]:
     """
     Wraps the given string to the specified width.
@@ -43,8 +45,76 @@ def wrap_text(string: str, width: int = 140, **kwargs) -> Optional[str]:
     )


-def setup_logging(verbosity=logging.INFO):
-    # Configure the root logger
+def get_logger(level: Optional[str] = None) -> logging.Logger:
+    """
+    Get a logger with a stream handler that captures all lm_eval logs.
+
+    Args:
+        level (Optional[str]): The logging level.
+
+    Example:
+        >>> logger = get_logger("INFO")
+        >>> logger.info("Log this")
+        INFO:lm_eval:Log this!
+
+    Returns:
+        logging.Logger: The logger.
+    """
+    logger = logging.getLogger("lm_eval")
+    if not logger.hasHandlers():
+        logger.addHandler(logging.StreamHandler())
+        logger.setLevel(logging.INFO)
+    if level is not None:
+        level = getattr(logging, level.upper())
+        logger.setLevel(level)
+    return logger
+
+
+def setup_logging(verbosity=logging.INFO, suppress_third_party=True):
+    """
+    Configure logging for the lm_eval CLI application.
+
+    WARNING: This function is intended for CLI use only. Library users should
+    use get_logger() instead to avoid interfering with their application's
+    logging configuration.
+
+    Args:
+        verbosity: Log level (int) or string name. Can be overridden by LOGLEVEL env var.
+        suppress_third_party: Whether to suppress verbose third-party library logs.
+
+    Returns:
+        logging.Logger: The configured lm_eval logger instance.
+    """
+    # Validate verbosity parameter
+    if isinstance(verbosity, str):
+        level_map = {
+            "DEBUG": logging.DEBUG,
+            "INFO": logging.INFO,
+            "WARNING": logging.WARNING,
+            "ERROR": logging.ERROR,
+            "CRITICAL": logging.CRITICAL,
+        }
+        verbosity = level_map.get(verbosity.upper(), logging.INFO)
+    elif not isinstance(verbosity, int):
+        verbosity = logging.INFO
+
+    # Get log level from environment or use default
+    if log_level_env := os.environ.get("LOGLEVEL", None):
+        level_map = {
+            "DEBUG": logging.DEBUG,
+            "INFO": logging.INFO,
+            "WARNING": logging.WARNING,
+            "ERROR": logging.ERROR,
+            "CRITICAL": logging.CRITICAL,
+        }
+        log_level = level_map.get(log_level_env.upper(), verbosity)
+    else:
+        log_level = verbosity
+
+    # Get the lm_eval logger directly
+    logger = logging.getLogger("lm_eval")
+
+    # Configure custom formatter
     class CustomFormatter(logging.Formatter):
         def format(self, record):
             if record.name.startswith("lm_eval."):
@@ -56,32 +126,27 @@ def setup_logging(verbosity=logging.INFO):
         datefmt="%Y-%m-%d:%H:%M:%S",
     )

-    log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity
-
-    level_map = {
-        "DEBUG": logging.DEBUG,
-        "INFO": logging.INFO,
-        "WARNING": logging.WARNING,
-        "ERROR": logging.ERROR,
-        "CRITICAL": logging.CRITICAL,
-    }
-    log_level = level_map.get(str(log_level).upper(), logging.INFO)
-    if not logging.root.handlers:
+    # Check if handler already exists to prevent duplicates
+    has_stream_handler = any(
+        isinstance(h, logging.StreamHandler) for h in logger.handlers
+    )
+    if not has_stream_handler:
         handler = logging.StreamHandler()
         handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        # For CLI use, we disable propagation to avoid duplicate messages
+        logger.propagate = False

-        root_logger = logging.getLogger()
-        root_logger.addHandler(handler)
-        root_logger.setLevel(log_level)
+    # Set the logger level
+    logger.setLevel(log_level)

-        if log_level == logging.DEBUG:
-            third_party_loggers = ["urllib3", "filelock", "fsspec"]
-            for logger_name in third_party_loggers:
-                logging.getLogger(logger_name).setLevel(logging.INFO)
-    else:
-        logging.getLogger().setLevel(log_level)
+    # Optionally suppress verbose third-party library logs
+    if suppress_third_party and log_level == logging.DEBUG:
+        third_party_loggers = ["urllib3", "filelock", "fsspec"]
+        for logger_name in third_party_loggers:
+            logging.getLogger(logger_name).setLevel(logging.INFO)
+
+    return logger
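A short usage sketch of the split introduced above, assuming both helpers land in lm_eval.utils as the surrounding diff suggests: library code asks for the namespaced logger and never touches the root logger, while the CLI calls setup_logging once at startup.

# Library-side: only the "lm_eval" logger is configured, root logging is untouched.
from lm_eval.utils import get_logger

logger = get_logger("INFO")
logger.info("loading tasks...")

# CLI-side (what the lm-eval entry point would do once at startup):
# from lm_eval.utils import setup_logging
# setup_logging(verbosity="DEBUG", suppress_third_party=True)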
 def hash_string(string: str) -> str:
@@ -108,7 +173,7 @@ def escaped_split(text, sep_char, maxsplit=-1):
         return text
     maxsplit = max(0, maxsplit)

-    return re.split(r"(?<!\\)" + sep_char, text, maxsplit)
+    return re.split(r"(?<!\\)" + sep_char, text, maxsplit=maxsplit)


 def handle_arg_string(arg):
@@ -125,7 +190,7 @@ def handle_arg_string(arg):


 def handle_non_serializable(o):
-    if isinstance(o, np.int64) or isinstance(o, np.int32):
+    if isinstance(o, np.integer):
         return int(o)
     elif isinstance(o, set):
         return list(o)
@@ -145,7 +210,7 @@ def sanitize_list(sub):
         return str(sub)


-def simple_parse_args_string(args_string: Optional[str]) -> dict:
+def simple_parse_args_string(args_string: str | None) -> dict:
     """
     Parses something like
         args1=val1,arg2=val2
@@ -180,7 +245,7 @@ def group(arr, fn):

 # Returns a list containing all values of the source_list that
 # match at least one of the patterns
-def pattern_match(patterns, source_list):
+def pattern_match(patterns: list[str], source_list: list[str]) -> list[str]:
     if isinstance(patterns, str):
         patterns = [patterns]
@@ -197,7 +262,7 @@ def softmax(x) -> np.ndarray:
     return e_x / e_x.sum()


-def general_detokenize(string) -> str:
+def general_detokenize(string: str) -> str:
     string = string.replace(" n't", "n't")
     string = string.replace(" )", ")")
     string = string.replace("( ", "(")
@@ -225,7 +290,7 @@ def sanitize_model_name(model_name: str) -> str:
     """
     Given the model name, returns a sanitized version of it.
     """
-    return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)
+    return re.sub(r"[\"<>:/|\\?*\[\]]+", "__", model_name)


 def sanitize_task_name(task_name: str) -> str:
@@ -235,21 +300,21 @@ def sanitize_task_name(task_name: str) -> str:
     return re.sub(r"\W", "_", task_name)


-def get_latest_filename(filenames: List[str]) -> str:
+def get_latest_filename(filenames: list[str]) -> str:
     """
     Given a list of filenames, returns the filename with the latest datetime.
     """
     return max(filenames, key=lambda f: get_file_datetime(f))


-def get_results_filenames(filenames: List[str]) -> List[str]:
+def get_results_filenames(filenames: list[str]) -> list[str]:
     """
     Extracts filenames that correspond to aggregated results.
     """
     return [f for f in filenames if "/results_" in f and ".json" in f]


-def get_sample_results_filenames(filenames: List[str]) -> List[str]:
+def get_sample_results_filenames(filenames: list[str]) -> list[str]:
     """
     Extracts filenames that correspond to sample results.
     """
@@ -257,8 +322,8 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:

 def get_rolling_token_windows(
-    token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
-) -> Generator[Tuple[List[int], List[int]], None, None]:
+    token_list: list[int], prefix_token: int, max_seq_len: int, context_len: int
+) -> Generator[tuple[list[int], list[int]], None, None]:
     """
     - context_len allows for a rolling window context, allowing each prediction window to potentially
       condition on some context
@@ -300,8 +365,8 @@ def get_rolling_token_windows(

 def make_disjoint_window(
-    pair: Tuple[List[int], List[int]],
-) -> Tuple[List[int], List[int]]:
+    pair: tuple[list[int], list[int]],
+) -> tuple[list[int], list[int]]:
     """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
     a, b = pair
     return a[: len(a) - (len(b) - 1)], b
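The hunk above only modernizes the type hints; behavior is unchanged. As a quick orientation for get_rolling_token_windows, here is a usage sketch with toy token ids (the printed window sizes, not exact contents, are the point):

from lm_eval.utils import get_rolling_token_windows, make_disjoint_window

tokens = list(range(10))          # toy "document" of 10 token ids
for pair in get_rolling_token_windows(
    token_list=tokens, prefix_token=-1, max_seq_len=5, context_len=2
):
    context, continuation = make_disjoint_window(pair)
    print(len(context), len(continuation))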
@@ -320,7 +385,7 @@ class EnhancedJSONEncoder(json.JSONEncoder):


 class Reorderer:
-    def __init__(self, arr: List[Any], fn: Callable) -> None:
+    def __init__(self, arr: list[Any], fn: Callable) -> None:
         """Reorder an array according to some function

         Args:
@@ -405,7 +470,8 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
         dic = result_dict[column][k]
         version = result_dict["versions"].get(k, " N/A")
         n = str(result_dict.get("n-shot", " ").get(k, " "))
-        higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})
+        # TODO: fix this
+        # higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})
         if "alias" in dic:
             k = dic.pop("alias")
@@ -418,13 +484,15 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
             if m.endswith("_stderr"):
                 continue

-            hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
+            # hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
+            # TODO: fix
+            hib = "↑"

-            v = "%.4f" % v if isinstance(v, float) else v
+            v = f"{v:.4f}" if isinstance(v, float) else v

             if m + "_stderr" + "," + f in dic:
                 se = dic[m + "_stderr" + "," + f]
-                se = " N/A" if se == "N/A" else "%.4f" % se
+                se = " N/A" if se == "N/A" else f"{se:.4f}"
                 values.append([k, version, f, n, m, hib, v, "±", se])
             else:
                 values.append([k, version, f, n, m, hib, v, "", ""])
@@ -445,7 +513,8 @@ def positional_deprecated(fn):
     wrapped function, `fn`.
     """

-    @functools.wraps(fn)
+    @wraps(fn)
     def _wrapper(*args, **kwargs):
         if len(args) != 1 if inspect.ismethod(fn) else 0:
             print(
@@ -484,14 +553,16 @@ def import_function(loader: yaml.Loader, node, yaml_path: Path):
     return function


-def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
+def load_yaml_config(
+    yaml_path: str | None = None, yaml_config=None, yaml_dir=None, mode="full"
+):
     if mode == "simple":
         constructor_fn = ignore_constructor
     elif mode == "full":
         if yaml_path is None:
             raise ValueError("yaml_path must be provided if mode is 'full'.")
         # Attach yaml_path to the import function so that it can be used later
-        constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))
+        constructor_fn = partial(import_function, yaml_path=Path(yaml_path))

     loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
     # Add the import_function constructor to the YAML loader
@@ -540,17 +611,28 @@ def regex_replace(string, pattern, repl, count: int = 0):

 env = Environment(
-    loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True
+    loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
 )
 env.filters["regex_replace"] = regex_replace


+@lru_cache(maxsize=128)
+def _compile(raw: str) -> Template:
+    return env.from_string(raw)
+
+
 def apply_template(template: str, doc: dict) -> str:
-    rtemplate = env.from_string(template)
+    rtemplate = _compile(template)
     return rtemplate.render(**doc)


-def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
+def create_iterator(
+    raw_iterator: collections.Iterator,
+    *,
+    rank: int = 0,
+    world_size: int = 1,
+    limit: int | None = None,
+) -> islice:
     """
     Method for creating a (potentially) sliced and limited
     iterator from a raw document iterator. Used for splitting data
...
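The new _compile helper memoizes Jinja template compilation so repeated apply_template calls with the same doc_to_text string skip re-parsing. The same pattern in isolation (standalone sketch, independent of the harness):

from functools import lru_cache
from jinja2 import BaseLoader, Environment, StrictUndefined, Template

env = Environment(loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True)

@lru_cache(maxsize=128)
def compile_template(raw: str) -> Template:
    # Parsing happens once per distinct template string; later calls hit the cache.
    return env.from_string(raw)

doc = {"question": "What is the capital of France"}
print(compile_template("Q: {{question}}?\nA:").render(**doc))  # -> "Q: What is the capital of France?\nA:"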
@@ -19,26 +19,20 @@ classifiers = [
 requires-python = ">=3.9"
 license = { "text" = "MIT" }
 dependencies = [
     "accelerate>=0.26.0",
-    "evaluate",
     "datasets>=2.16.0,<4.0",
     "evaluate>=0.4.0",
-    "jsonlines",
-    "numexpr",
     "peft>=0.2.0",
-    "pybind11>=2.6.2",
     "pytablewriter",
     "rouge-score>=0.0.4",
     "sacrebleu>=1.5.0",
     "scikit-learn>=0.24.1",
     "sqlitedict",
     "torch>=1.8",
-    "tqdm-multiprocess",
     "transformers>=4.1",
-    "zstandard",
     "dill",
     "word2number",
-    "more_itertools",
+    "more_itertools"
 ]

 [tool.setuptools.packages.find]
@@ -68,7 +62,7 @@ ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"]
 ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
 ipex = ["optimum"]
 japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"]
-longbench=["jieba", "fuzzywuzzy", "rouge"]
+longbench = ["jieba", "fuzzywuzzy", "rouge"]
 libra=["pymorphy2"]
 mamba = ["mamba_ssm", "causal-conv1d==1.0.2", "torch"]
 math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"]
@@ -87,17 +81,30 @@ vllm = ["vllm>=0.4.2"]
 wandb = ["wandb>=0.16.3", "pandas", "numpy"]
 zeno = ["pandas", "zeno-client"]
 tasks = [
     "lm_eval[acpbench]",
     "lm_eval[discrim_eval]",
     "lm_eval[ifeval]",
     "lm_eval[japanese_leaderboard]",
     "lm_eval[longbench]",
     "lm_eval[libra]",
     "lm_eval[mamba]",
     "lm_eval[math]",
     "lm_eval[multilingual]",
-    "lm_eval[ruler]",
+    "lm_eval[ruler]"
 ]
+testing = ["pytest", "pytest-cov", "pytest-xdist"]
+unitxt = ["unitxt==1.22.0"]
+vllm = ["vllm>=0.4.2"]
+wandb = ["wandb>=0.16.3", "pandas", "numpy"]
+zeno = ["pandas", "zeno-client"]
+
+[project.scripts]
+lm-eval = "lm_eval.__main__:cli_evaluate"
+lm_eval = "lm_eval.__main__:cli_evaluate"
+
+[project.urls]
+Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
+Repository = "https://github.com/EleutherAI/lm-evaluation-harness"

 [tool.pymarkdown]
 plugins.md013.enabled = false # line-length
@@ -107,18 +114,23 @@ plugins.md028.enabled = false # no-blanks-blockquote
 plugins.md029.allow_extended_start_values = true # ol-prefix
 plugins.md034.enabled = false # no-bare-urls

-[tool.ruff.lint]
-extend-select = ["I", "W605", "UP"]
+[tool.ruff]
+target-version = "py39"
+lint.extend-select = ["I", "UP", "E", "C419", "F", "B", "SIM", "RUF034", "W605", "FURB"]
+lint.fixable = ["I001", "F401", "UP"]
+lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117", "E741"]
+
+[tool.ruff.lint.extend-per-file-ignores]
+"__init__.py" = ["F401", "F402", "F403", "F405"]

 [tool.ruff.lint.isort]
-lines-after-imports = 2
+combine-as-imports = true
 known-first-party = ["lm_eval"]
+lines-after-imports = 2

-[tool.ruff.lint.extend-per-file-ignores]
-"__init__.py" = ["F401","F402","F403","F405"]
-"utils.py" = ["F401"]
-
-[dependency-groups]
-dev = [
-    "api","dev","sentencepiece"
-]
+# required to include yaml files in pip installation
+[tool.setuptools.package-data]
+lm_eval = ["**/*.yaml", "tasks/**/*"]
+
+[tool.setuptools.packages.find]
+include = ["lm_eval*"]
# Language Model Evaluation Harness Configuration File
#
# This YAML configuration file allows you to specify evaluation parameters
# instead of passing them as command-line arguments.
#
# Usage:
# $ lm_eval --config templates/example_ci_config.yaml
#
# You can override any values in this config with further command-line arguments:
# $ lm_eval --config templates/example_ci_config.yaml --model_args pretrained=gpt2 --tasks mmlu
#
# For expected types and values, refer to EvaluatorConfig in lm_eval/config/evaluate_config.py
# All parameters are optional and have the same meaning as their CLI counterparts.
model: hf
model_args:
pretrained: EleutherAI/pythia-14m
dtype: float16
tasks:
- hellaswag
- arc_easy
batch_size: 1
trust_remote_code: true
log_samples: true
output_path: ./test
gen_kwargs:
do_sample: true
temperature: 0.7
stop: ["\n", "<|endoftext|>"]
samples:
hellaswag: [1,2,3,4,5,6,7,8,9,10]
arc_easy: [10,20,30,40,50,60,70,80,90,100]
metadata:
name: Example CI Config
description: This is an example configuration file for testing purposes.
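The CLI config above also has a programmatic counterpart; a rough sketch using lm_eval.simple_evaluate (the function and these keyword names appear in the tests later in this commit, but treat the exact call as an approximation):

import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-14m,dtype=float16",
    tasks=["hellaswag", "arc_easy"],
    batch_size=1,
    log_samples=True,
)
print(results["results"])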
import argparse
import sys
from unittest.mock import MagicMock, patch
import pytest
from lm_eval._cli.harness import HarnessCLI
from lm_eval._cli.ls import List
from lm_eval._cli.run import Run
from lm_eval._cli.utils import (
_int_or_none_list_arg_type,
check_argument_types,
request_caching_arg_to_dict,
try_parse_json,
)
from lm_eval._cli.validate import Validate
class TestHarnessCLI:
"""Test the main HarnessCLI class."""
def test_harness_cli_init(self):
"""Test HarnessCLI initialization."""
cli = HarnessCLI()
assert cli._parser is not None
assert cli._subparsers is not None
def test_harness_cli_has_subcommands(self):
"""Test that HarnessCLI has all expected subcommands."""
cli = HarnessCLI()
subcommands = cli._subparsers.choices
assert "run" in subcommands
assert "ls" in subcommands
assert "validate" in subcommands
def test_harness_cli_backward_compatibility(self):
"""Test backward compatibility: inserting 'run' when no subcommand is provided."""
cli = HarnessCLI()
test_args = ["lm-eval", "--model", "hf", "--tasks", "hellaswag"]
with patch.object(sys, "argv", test_args):
args = cli.parse_args()
assert args.command == "run"
assert args.model == "hf"
assert args.tasks == "hellaswag"
def test_harness_cli_help_default(self):
"""Test that help is printed when no arguments are provided."""
cli = HarnessCLI()
with patch.object(sys, "argv", ["lm-eval"]):
args = cli.parse_args()
# The func is a lambda that calls print_help
# Let's test it calls the help function correctly
with patch.object(cli._parser, "print_help") as mock_help:
args.func(args)
mock_help.assert_called_once()
def test_harness_cli_run_help_only(self):
"""Test that 'lm-eval run' shows help."""
cli = HarnessCLI()
with patch.object(sys, "argv", ["lm-eval", "run"]):
with pytest.raises(SystemExit):
cli.parse_args()
class TestListCommand:
"""Test the List subcommand."""
def test_list_command_creation(self):
"""Test List command creation."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
list_cmd = List.create(subparsers)
assert isinstance(list_cmd, List)
def test_list_command_arguments(self):
"""Test List command arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
List.create(subparsers)
# Test valid arguments
args = parser.parse_args(["ls", "tasks"])
assert args.what == "tasks"
assert args.include_path is None
args = parser.parse_args(["ls", "groups", "--include_path", "/path/to/tasks"])
assert args.what == "groups"
assert args.include_path == "/path/to/tasks"
def test_list_command_choices(self):
"""Test List command only accepts valid choices."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
List.create(subparsers)
# Valid choices should work
for choice in ["tasks", "groups", "subtasks", "tags"]:
args = parser.parse_args(["ls", choice])
assert args.what == choice
# Invalid choice should fail
with pytest.raises(SystemExit):
parser.parse_args(["ls", "invalid"])
@patch("lm_eval.tasks.TaskManager")
def test_list_command_execute_tasks(self, mock_task_manager):
"""Test List command execution for tasks."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
list_cmd = List.create(subparsers)
mock_tm_instance = MagicMock()
mock_tm_instance.list_all_tasks.return_value = "task1\ntask2\ntask3"
mock_task_manager.return_value = mock_tm_instance
args = parser.parse_args(["ls", "tasks"])
with patch("builtins.print") as mock_print:
list_cmd._execute(args)
mock_print.assert_called_once_with("task1\ntask2\ntask3")
mock_tm_instance.list_all_tasks.assert_called_once_with()
@patch("lm_eval.tasks.TaskManager")
def test_list_command_execute_groups(self, mock_task_manager):
"""Test List command execution for groups."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
list_cmd = List.create(subparsers)
mock_tm_instance = MagicMock()
mock_tm_instance.list_all_tasks.return_value = "group1\ngroup2"
mock_task_manager.return_value = mock_tm_instance
args = parser.parse_args(["ls", "groups"])
with patch("builtins.print") as mock_print:
list_cmd._execute(args)
mock_print.assert_called_once_with("group1\ngroup2")
mock_tm_instance.list_all_tasks.assert_called_once_with(
list_subtasks=False, list_tags=False
)
class TestRunCommand:
"""Test the Run subcommand."""
def test_run_command_creation(self):
"""Test Run command creation."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
run_cmd = Run.create(subparsers)
assert isinstance(run_cmd, Run)
def test_run_command_basic_arguments(self):
"""Test Run command basic arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Run.create(subparsers)
args = parser.parse_args(
["run", "--model", "hf", "--tasks", "hellaswag,arc_easy"]
)
assert args.model == "hf"
assert args.tasks == "hellaswag,arc_easy"
def test_run_command_model_args(self):
"""Test Run command model arguments parsing."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Run.create(subparsers)
# Test key=value format
args = parser.parse_args(["run", "--model_args", "pretrained=gpt2,device=cuda"])
assert args.model_args == "pretrained=gpt2,device=cuda"
# Test JSON format
args = parser.parse_args(
["run", "--model_args", '{"pretrained": "gpt2", "device": "cuda"}']
)
assert args.model_args == {"pretrained": "gpt2", "device": "cuda"}
def test_run_command_batch_size(self):
"""Test Run command batch size arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Run.create(subparsers)
# Test integer batch size
args = parser.parse_args(["run", "--batch_size", "32"])
assert args.batch_size == "32"
# Test auto batch size
args = parser.parse_args(["run", "--batch_size", "auto"])
assert args.batch_size == "auto"
# Test auto with repetitions
args = parser.parse_args(["run", "--batch_size", "auto:5"])
assert args.batch_size == "auto:5"
def test_run_command_seed_parsing(self):
"""Test Run command seed parsing."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Run.create(subparsers)
# Test single seed
args = parser.parse_args(["run", "--seed", "42"])
assert args.seed == [42, 42, 42, 42]
# Test multiple seeds
args = parser.parse_args(["run", "--seed", "0,1234,5678,9999"])
assert args.seed == [0, 1234, 5678, 9999]
# Test with None values
args = parser.parse_args(["run", "--seed", "0,None,1234,None"])
assert args.seed == [0, None, 1234, None]
@patch("lm_eval.simple_evaluate")
@patch("lm_eval.config.evaluate_config.EvaluatorConfig")
@patch("lm_eval.loggers.EvaluationTracker")
@patch("lm_eval.utils.make_table")
def test_run_command_execute_basic(
self, mock_make_table, mock_tracker, mock_config, mock_simple_evaluate
):
"""Test Run command basic execution."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
run_cmd = Run.create(subparsers)
# Mock configuration
mock_cfg_instance = MagicMock()
mock_cfg_instance.wandb_args = None
mock_cfg_instance.output_path = None
mock_cfg_instance.hf_hub_log_args = {}
mock_cfg_instance.include_path = None
mock_cfg_instance.tasks = ["hellaswag"]
mock_cfg_instance.model = "hf"
mock_cfg_instance.model_args = {"pretrained": "gpt2"}
mock_cfg_instance.gen_kwargs = {}
mock_cfg_instance.limit = None
mock_cfg_instance.num_fewshot = 0
mock_cfg_instance.batch_size = 1
mock_cfg_instance.log_samples = False
mock_cfg_instance.process_tasks.return_value = MagicMock()
mock_config.from_cli.return_value = mock_cfg_instance
# Mock evaluation results
mock_simple_evaluate.return_value = {
"results": {"hellaswag": {"acc": 0.75}},
"config": {"batch_sizes": [1]},
"configs": {"hellaswag": {}},
"versions": {"hellaswag": "1.0"},
"n-shot": {"hellaswag": 0},
}
# Mock make_table to avoid complex table rendering
mock_make_table.return_value = (
"| Task | Result |\n|------|--------|\n| hellaswag | 0.75 |"
)
args = parser.parse_args(["run", "--model", "hf", "--tasks", "hellaswag"])
with patch("builtins.print"):
run_cmd._execute(args)
mock_config.from_cli.assert_called_once()
mock_simple_evaluate.assert_called_once()
mock_make_table.assert_called_once()
class TestValidateCommand:
"""Test the Validate subcommand."""
def test_validate_command_creation(self):
"""Test Validate command creation."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
validate_cmd = Validate.create(subparsers)
assert isinstance(validate_cmd, Validate)
def test_validate_command_arguments(self):
"""Test Validate command arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Validate.create(subparsers)
args = parser.parse_args(["validate", "--tasks", "hellaswag,arc_easy"])
assert args.tasks == "hellaswag,arc_easy"
assert args.include_path is None
args = parser.parse_args(
["validate", "--tasks", "custom_task", "--include_path", "/path/to/tasks"]
)
assert args.tasks == "custom_task"
assert args.include_path == "/path/to/tasks"
def test_validate_command_requires_tasks(self):
"""Test Validate command requires tasks argument."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Validate.create(subparsers)
with pytest.raises(SystemExit):
parser.parse_args(["validate"])
@patch("lm_eval.tasks.TaskManager")
def test_validate_command_execute_success(self, mock_task_manager):
"""Test Validate command execution with valid tasks."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
validate_cmd = Validate.create(subparsers)
mock_tm_instance = MagicMock()
mock_tm_instance.match_tasks.return_value = ["hellaswag", "arc_easy"]
mock_task_manager.return_value = mock_tm_instance
args = parser.parse_args(["validate", "--tasks", "hellaswag,arc_easy"])
with patch("builtins.print") as mock_print:
validate_cmd._execute(args)
mock_print.assert_any_call("Validating tasks: ['hellaswag', 'arc_easy']")
mock_print.assert_any_call("All tasks found and valid")
@patch("lm_eval.tasks.TaskManager")
def test_validate_command_execute_missing_tasks(self, mock_task_manager):
"""Test Validate command execution with missing tasks."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
validate_cmd = Validate.create(subparsers)
mock_tm_instance = MagicMock()
mock_tm_instance.match_tasks.return_value = ["hellaswag"]
mock_task_manager.return_value = mock_tm_instance
args = parser.parse_args(["validate", "--tasks", "hellaswag,nonexistent"])
with patch("builtins.print") as mock_print:
with pytest.raises(SystemExit) as exc_info:
validate_cmd._execute(args)
assert exc_info.value.code == 1
mock_print.assert_any_call("Tasks not found: nonexistent")
class TestCLIUtils:
"""Test CLI utility functions."""
def test_try_parse_json_with_json_string(self):
"""Test try_parse_json with valid JSON string."""
result = try_parse_json('{"key": "value", "num": 42}')
assert result == {"key": "value", "num": 42}
def test_try_parse_json_with_dict(self):
"""Test try_parse_json with dict input."""
input_dict = {"key": "value"}
result = try_parse_json(input_dict)
assert result is input_dict
def test_try_parse_json_with_none(self):
"""Test try_parse_json with None."""
result = try_parse_json(None)
assert result is None
def test_try_parse_json_with_plain_string(self):
"""Test try_parse_json with plain string."""
result = try_parse_json("key=value,key2=value2")
assert result == "key=value,key2=value2"
def test_try_parse_json_with_invalid_json(self):
"""Test try_parse_json with invalid JSON."""
with pytest.raises(ValueError) as exc_info:
try_parse_json('{key: "value"}') # Invalid JSON (unquoted key)
assert "Invalid JSON" in str(exc_info.value)
assert "double quotes" in str(exc_info.value)
def test_int_or_none_list_single_value(self):
"""Test _int_or_none_list_arg_type with single value."""
result = _int_or_none_list_arg_type(3, 4, "0,1,2,3", "42")
assert result == [42, 42, 42, 42]
def test_int_or_none_list_multiple_values(self):
"""Test _int_or_none_list_arg_type with multiple values."""
result = _int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,20,30,40")
assert result == [10, 20, 30, 40]
def test_int_or_none_list_with_none(self):
"""Test _int_or_none_list_arg_type with None values."""
result = _int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,None,30,None")
assert result == [10, None, 30, None]
def test_int_or_none_list_invalid_value(self):
"""Test _int_or_none_list_arg_type with invalid value."""
with pytest.raises(ValueError):
_int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,invalid,30,40")
def test_int_or_none_list_too_few_values(self):
"""Test _int_or_none_list_arg_type with too few values."""
with pytest.raises(ValueError):
_int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,20")
def test_int_or_none_list_too_many_values(self):
"""Test _int_or_none_list_arg_type with too many values."""
with pytest.raises(ValueError):
_int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,20,30,40,50")
def test_request_caching_arg_to_dict_none(self):
"""Test request_caching_arg_to_dict with None."""
result = request_caching_arg_to_dict(None)
assert result == {}
def test_request_caching_arg_to_dict_true(self):
"""Test request_caching_arg_to_dict with 'true'."""
result = request_caching_arg_to_dict("true")
assert result == {
"cache_requests": True,
"rewrite_requests_cache": False,
"delete_requests_cache": False,
}
def test_request_caching_arg_to_dict_refresh(self):
"""Test request_caching_arg_to_dict with 'refresh'."""
result = request_caching_arg_to_dict("refresh")
assert result == {
"cache_requests": True,
"rewrite_requests_cache": True,
"delete_requests_cache": False,
}
def test_request_caching_arg_to_dict_delete(self):
"""Test request_caching_arg_to_dict with 'delete'."""
result = request_caching_arg_to_dict("delete")
assert result == {
"cache_requests": False,
"rewrite_requests_cache": False,
"delete_requests_cache": True,
}
def test_check_argument_types_raises_on_untyped(self):
"""Test check_argument_types raises error for untyped arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--untyped") # No type specified
with pytest.raises(ValueError) as exc_info:
check_argument_types(parser)
assert "untyped" in str(exc_info.value)
assert "doesn't have a type specified" in str(exc_info.value)
def test_check_argument_types_passes_on_typed(self):
"""Test check_argument_types passes for typed arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--typed", type=str)
# Should not raise
check_argument_types(parser)
def test_check_argument_types_skips_const_actions(self):
"""Test check_argument_types skips const actions."""
parser = argparse.ArgumentParser()
parser.add_argument("--flag", action="store_const", const=True)
# Should not raise
check_argument_types(parser)
 import unittest.mock as mock

 from lm_eval.api.metrics import _bootstrap_internal_no_mp, mean
-from lm_eval.api.task import ConfigurableTask, TaskConfig
+from lm_eval.api.task import ConfigurableTask
+from lm_eval.config.task import TaskConfig


 class MockConfigurableTask(ConfigurableTask):
...
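For orientation, the _bootstrap_internal_no_mp helper imported above backs bootstrap standard-error estimation for aggregated metrics. The general technique, as a generic sketch that is not the harness's implementation:

import random
from statistics import mean, stdev

def bootstrap_stderr(values, iters=1000, seed=1234):
    # Resample with replacement, recompute the mean each time, and report the spread.
    rng = random.Random(seed)
    resampled_means = [mean(rng.choices(values, k=len(values))) for _ in range(iters)]
    return stdev(resampled_means)

print(bootstrap_stderr([0, 1, 1, 0, 1, 1, 1, 0]))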