Unverified Commit 003e5852 authored by Baber Abbasi, committed by GitHub

Fewshot refactor (#3227)



* overhaul `ContextSampler`

* refactor masakhapos

* move multi_target to `exact_match`

* remove doc_to_choice from `boolq-seq2seq`

* remove doc_to_choice in generation process_results

* Remove unused `doc_to_choice` and fix superglue whitespaces

* require multiple_inputs and multiple_targets to be explicitly set in taskconfig

* fix copa; better logging in task init

* fix doc_to_target to return int rather than str (deprecated)

* fix processing regression; recursively parse lists from template

* remove redundant jinja parsing logic

* remove promptsource

* for multiple_inputs use `doc_to_text: list[str]`

* Refactor `ContextSampler` `fewshot_context`

* fix multiple_input context

* fix `target_delimiter` with `gen_prefix`

* `doc_to_text` is list for multiple_inputs

* Refactor `count_bytes` and `count_words` methods to `@staticmethod`

* make has_*(train/test/validation) into properties

* remove `multi_target` from `generate_until`

* fix doc_to_target/multiple_targets handling; add tests

* rename `multi_target` to `multiple_targets`

* evaluate list when multiple targets (see the `exact_match` sketch below)

* allow doc_to_target to return list

* Remove gen_prefix space and add warning (#3239)

* Remove gen_prefix space and add warning

* fix null gen_prefix bug again

* use git tests

---------
Co-authored-by: Boaz Ben-Dov <bendboaz@gmail.com>
parent 79a22a11
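
The central rename is `multi_target` -> `multiple_targets`, now set explicitly both at the task level and on the `exact_match` metric (see the triviaqa diff below). As a rough, hedged sketch of the intended semantics, not the harness's actual implementation: a generation counts as correct when it matches any of the gold targets, so scores should not change.

```python
# Illustrative multiple-target exact match (function name is hypothetical);
# the real metric also honors ignore_case / ignore_punctuation from the YAML.
import string


def exact_match_multiple_targets(prediction: str, targets: list[str]) -> float:
    def normalize(text: str) -> str:
        text = text.lower()
        return text.translate(str.maketrans("", "", string.punctuation)).strip()

    pred = normalize(prediction)
    # Correct if the prediction matches any of the gold aliases.
    return float(any(pred == normalize(target) for target in targets))


print(exact_match_multiple_targets("Paris!", ["paris", "City of Paris"]))  # 1.0
```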
 tag:
   - super-glue-lm-eval-v1
 task: copa
-dataset_path: super_glue
+dataset_path: aps/super_glue
 dataset_name: copa
 output_type: multiple_choice
 training_split: train
 validation_split: validation
 doc_to_text: !function utils.doc_to_text
-doc_to_target: !function utils.doc_to_target
+doc_to_target: label
-doc_to_choice: !function utils.doc_to_choice
+doc_to_choice: ["{{ choice1 }}", "{{ choice2 }}"]
 metric_list:
   - metric: acc
 metadata:
...
 tag:
   - super-glue-t5-prompt
 task: super_glue-copa-t5-prompt
-dataset_path: super_glue
+dataset_path: aps/super_glue
 dataset_name: copa
 training_split: train
 validation_split: validation
 output_type: generate_until
 doc_to_text: "copa choice1: {{choice1}} choice2: {{choice2}} premise: {{premise}} question: {{question}}"
-doc_to_target: label
+doc_to_target: "{{ [choice1, choice2][label|int] }}"
-doc_to_choice: ['choice1', 'choice2']
 generation_kwargs:
   until:
     - "</s>"
...
 tag:
   - super-glue-lm-eval-v1
 task: multirc
-dataset_path: super_glue
+dataset_path: aps/super_glue
 dataset_name: multirc
 output_type: multiple_choice
 training_split: train
...
 tag:
   - super-glue-t5-prompt
 task: super_glue-multirc-t5-prompt
-dataset_path: super_glue
+dataset_path: aps/super_glue
 dataset_name: multirc
 training_split: train
 validation_split: validation
 output_type: generate_until
-doc_to_text: "multirc question: {{question}} answer: {{answer}} paragraph: {{paragraph}}"
+doc_to_text: "multirc question: {{question}} answer: {{answer}} paragraph: {{paragraph}}|trim"
-doc_to_target: label
+doc_to_target: "{% set group_id = idx.question|string %}{{[group_id+'_False', group_id+'_True'][label]}}"
-doc_to_choice: "{% set group_id = idx.question|string %}{{[group_id+'_False', group_id+'_True']}}"
 generation_kwargs:
   until:
     - "</s>"
...
 tag:
   - super-glue-lm-eval-v1
 task: record
-dataset_path: super_glue
+dataset_path: aps/super_glue
 dataset_name: record
 output_type: multiple_choice
 training_split: train
@@ -11,6 +11,7 @@ doc_to_target: !function util.doc_to_target
 doc_to_choice: !function util.doc_to_choice
 process_docs: !function util.process_docs
 process_results: !function util.process_results
+target_delimiter: ""
 metric_list:
   - metric: f1
     aggregation: mean
...
 tag:
   - super-glue-t5-prompt
 task: super_glue-record-t5-prompt
-dataset_path: super_glue
+dataset_path: aps/super_glue
 dataset_name: record
 validation_split: validation
 output_type: generate_until
...
@@ -19,7 +19,7 @@ def format_answer(query, entity):
 def doc_to_target(doc):
     # We only output the first correct entity in a doc
-    return format_answer(query=doc["query"], entity=doc["answers"][0])
+    return doc["entities"].index(doc["answers"][0])
 def doc_to_choice(doc):
...
 tag:
   - super-glue-lm-eval-v1
 task: sglue_rte
-dataset_path: super_glue
+dataset_path: aps/super_glue
 dataset_name: rte
 output_type: multiple_choice
 training_split: train
 validation_split: validation
 doc_to_text: "{{premise}}\nQuestion: {{hypothesis}} True or False?\nAnswer:"
 doc_to_target: label
-doc_to_choice: ['True', 'False']
+doc_to_choice: ["True", "False"]
 metric_list:
   - metric: acc
 metadata:
...
 tag:
   - super-glue-t5-prompt
 task: super_glue-rte-t5-prompt
-dataset_path: super_glue
+dataset_path: aps/super_glue
 dataset_name: rte
 training_split: train
 validation_split: validation
 output_type: generate_until
 doc_to_text: "rte hypothesis: {{hypothesis}} premise: {{premise}}"
-doc_to_target: label
+doc_to_target: "{{ ['entailment', 'not_entailment'][label|int] }}"
-doc_to_choice: ['entailment', 'not_entailment']
 generation_kwargs:
   until:
     - "</s>"
...
 tag:
   - super-glue-lm-eval-v1
 task: "wic"
-dataset_path: super_glue
+dataset_path: aps/super_glue
 dataset_name: wic
 output_type: multiple_choice
 training_split: train
 validation_split: validation
 doc_to_text: "Sentence 1: {{sentence1}}\nSentence 2: {{sentence2}}\nQuestion: Is the word '{{sentence1[start1:end1]}}' used in the same way in the two sentences above?\nAnswer:"
 doc_to_target: label
-doc_to_choice: ['no', 'yes']
+doc_to_choice: ["no", "yes"]
 metric_list:
   - metric: acc
 metadata:
...
 tag:
   - super-glue-t5-prompt
 task: super_glue-wic-t5-prompt
-dataset_path: super_glue
+dataset_path: aps/super_glue
 dataset_name: wic
 training_split: train
 validation_split: validation
 output_type: generate_until
 doc_to_text: "wic sentence1: {{sentence1}} sentence2: {{sentence2}} word: {{word}}"
-doc_to_target: label
+doc_to_target: "{{ ['False', 'True'][label|int] }}"
-doc_to_choice: ['False', 'True']
 generation_kwargs:
   until:
     - "</s>"
...
 tag:
   - super-glue-lm-eval-v1
 task: wsc
-dataset_path: super_glue
+dataset_path: aps/super_glue
 dataset_name: wsc.fixed
 output_type: multiple_choice
 training_split: train
 validation_split: validation
 doc_to_text: !function preprocess_wsc.default_doc_to_text
 doc_to_target: label
-doc_to_choice: ['no', 'yes']
+doc_to_choice: ["no", "yes"]
 metric_list:
   - metric: acc
 metadata:
...
 tag:
   - super-glue-t5-prompt
 task: super_glue-wsc-t5-prompt
-dataset_path: super_glue
+dataset_path: aps/super_glue
 dataset_name: wsc.fixed
 training_split: train
 validation_split: validation
...
@@ -12,7 +12,6 @@ high quality distant supervision for answering the questions.
 Homepage: https://nlp.cs.washington.edu/triviaqa/
 ### Citation
 ```
@@ -40,15 +39,18 @@ Homepage: https://nlp.cs.washington.edu/triviaqa/
 ### Checklist
 For adding novel benchmarks/datasets to the library:
-* [ ] Is the task an existing benchmark in the literature?
-* [ ] Have you referenced the original paper that introduced the task?
-* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
+* [ ] Is the task an existing benchmark in the literature?
+* [ ] Have you referenced the original paper that introduced the task?
+* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the
+  reference implementation and documented how to run such a test?
 If other tasks on this dataset are already supported:
 * [ ] Is the "Main" variant of this task clearly denoted?
 * [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
 * [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
 ### Changelog
-* 2025-07-21: Added `multi_target` to `exact_match`. Scores should not change.
+* 2025-07-21: Added `multiple_targets` to `exact_match`. Scores should not change.
@@ -6,6 +6,7 @@ training_split: train
 validation_split: validation
 doc_to_text: "Question: {{question}}?\nAnswer:"
 doc_to_target: "{{answer.aliases}}"
+multiple_targets: true
 should_decontaminate: true
 doc_to_decontamination_query: question
 generation_kwargs:
@@ -27,6 +28,6 @@ metric_list:
     higher_is_better: true
     ignore_case: true
     ignore_punctuation: true
-    multi_target: true
+    multiple_targets: true
 metadata:
   version: 3.0
@@ -9,6 +9,7 @@ doc_to_target: !function preprocess_winogrande.doc_to_target
 doc_to_choice: !function preprocess_winogrande.doc_to_choice
 should_decontaminate: true
 doc_to_decontamination_query: sentence
+multiple_inputs: true
 metric_list:
   - metric: acc
     aggregation: mean
...
-def doc_to_text(doc):
+def doc_to_target(doc) -> int:
     answer_to_num = {"1": 0, "2": 1}
     return answer_to_num[doc["answer"]]

-def doc_to_target(doc):
+def doc_to_choice(doc) -> list[str]:
     idx = doc["sentence"].index("_") + 1
-    return doc["sentence"][idx:].strip()
+    return [doc["sentence"][idx:].strip()]

-def doc_to_choice(doc):
+def doc_to_text(doc) -> list[str]:
     idx = doc["sentence"].index("_")
     options = [doc["option1"], doc["option2"]]
     return [doc["sentence"][:idx] + opt for opt in options]
@@ -12,10 +12,9 @@ import os
 import re
 from collections.abc import Generator
 from dataclasses import asdict, is_dataclass
-from functools import lru_cache, partial, wraps
+from functools import wraps
 from itertools import islice
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable
 import numpy as np
 from jinja2 import BaseLoader, Environment, StrictUndefined
@@ -27,7 +26,9 @@ HIGHER_IS_BETTER_SYMBOLS = {
     True: "↑",
     False: "↓",
 }
-def wrap_text(string: str, width: int = 140, **kwargs) -> Optional[str]:
+def wrap_text(string: str, width: int = 140, **kwargs) -> str | None:
     """
     Wraps the given string to the specified width.
     """
@@ -44,8 +45,7 @@ def wrap_text(string: str, width: int = 140, **kwargs) -> Optional[str]:
     )
-def get_logger(level: Optional[str] = None) -> logging.Logger:
+def get_logger(level: str | None = None) -> logging.Logger:
     """
     Get a logger with a stream handler that captures all lm_eval logs.
@@ -626,3 +626,7 @@ def apply_template(template: str, doc: dict) -> str:
         apply_template._env.filters["regex_replace"] = regex_replace
     return _compile_tpl(template).render(**doc)
+
+
+def validate_index(index: int, length: int) -> int:
+    return index if index < length else -100
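
The new `validate_index` helper maps an out-of-range index to -100, a conventional "ignore" value. A hypothetical usage sketch, assuming it is meant to guard a gold label against the number of available choices (the actual call sites are not shown in this hunk):

```python
# Hypothetical usage of validate_index; -100 is assumed to signal "no valid
# gold label", so downstream scoring can skip such documents.
def validate_index(index: int, length: int) -> int:
    return index if index < length else -100


choices = ["no", "yes"]
print(validate_index(1, len(choices)))   # 1: valid gold index
print(validate_index(5, len(choices)))   # -100: out of range, ignored
```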
"""Tests for ConfigurableTask doc_to_* methods with Jinja/YAML parsing.
This test suite documents and validates all expected YAML input types for the doc_to_* methods:
doc_to_text - Transforms a document into the input text for the model:
- String field name: References a field directly from the document
YAML: doc_to_text: "question"
- Jinja2 template: Renders a template with document fields
YAML: doc_to_text: "Question: {{question}}\nContext: {{context}}"
- Integer: Returns a constant integer value
YAML: doc_to_text: 0
- Python function: Applies a callable function (via !function directive)
YAML: doc_to_text: !function utils.my_custom_function
doc_to_target - Transforms a document into the expected target/answer:
- String field name: References a field directly from the document
YAML: doc_to_target: "answer"
- Jinja2 template: Renders a template, can return string or int for multiple choice
YAML: doc_to_target: "{{answers[correct_idx]}}"
YAML: doc_to_target: "{{label}}" # "0", "1", etc. converted to int if doc_to_choice exists
- Integer: Returns a constant integer value (typically for multiple choice)
YAML: doc_to_target: 0
- List of templates: Returns multiple targets: list[str]
YAML: doc_to_target: ["{{answer1}}", "{{answer2}}"]
- Python function: Applies a callable function
YAML: doc_to_target: !function utils.extract_answer
doc_to_choice - Defines the list of choices for multiple choice tasks:
- String field name: References a list field from the document
YAML: doc_to_choice: "options"
- Jinja2 template returning list: Template that evaluates to a list
YAML: doc_to_choice: "{{choices}}" # Must render to "['A', 'B', 'C']" format
YAML: doc_to_choice: "{{[correct, wrong]}}" # Creates list literal from fields
YAML: doc_to_choice: "{{options if options else default_options}}"
- List of templates: Each template becomes a choice
YAML: doc_to_choice: ["{{choice_a}}", "{{choice_b}}", "{{choice_c}}"]
- Dictionary: Values become the choices (keys are ignored)
YAML: doc_to_choice:
A: "First option"
B: "Second option"
C: "Third option"
- Python function: Returns a list of choices
YAML: doc_to_choice: !function utils.generate_choices
Special Jinja2 features supported:
- Filters: {{text|upper}}, {{text|lower}}, {{text|regex_replace('pattern', 'replacement')}}
- Conditionals: {{field1 if condition else field2}}
- List operations: {{', '.join(items)}}
- Nested field access: {{metadata.answer}}, {{choices[0]}}
- Math operations: {{score * 100}}
- String concatenation: {{first + ' ' + last}}
"""
from unittest.mock import Mock, patch
import pytest
from lm_eval.api.task import ConfigurableTask
class TestDocToTextMethod:
"""Test suite for doc_to_text method."""
def test_doc_to_text_with_string_field(self):
"""Test doc_to_text when config points to a field name."""
task = Mock(spec=ConfigurableTask)
task.multiple_inputs = False
task.features = ["text", "answer", "choices", "label"]
task.config = Mock()
task.config.doc_to_text = "text"
doc = {"text": "This is a test question", "answer": "A"}
result = ConfigurableTask.doc_to_text(task, doc)
assert result == "This is a test question"
def test_doc_to_text_with_jinja_template(self):
"""Test doc_to_text with Jinja template."""
task = Mock(spec=ConfigurableTask)
task.multiple_inputs = False
task.features = ["text", "answer"]
task.config = Mock()
task.config.doc_to_text = "Question: {{text}}"
doc = {"text": "What is 2+2?", "answer": "4"}
result = ConfigurableTask.doc_to_text(task, doc)
assert result == "Question: What is 2+2?"
def test_doc_to_text_with_complex_jinja(self):
"""Test doc_to_text with complex Jinja expressions."""
task = Mock(spec=ConfigurableTask)
task.multiple_inputs = False
task.features = ["text", "answer"]
task.config = Mock()
task.config.doc_to_text = "{{text|upper}} - {{answer|lower}}"
doc = {"text": "Test", "answer": "ANSWER"}
result = ConfigurableTask.doc_to_text(task, doc)
assert result == "TEST - answer"
def test_doc_to_text_with_list(self):
"""Test doc_to_text when config is an integer."""
task = Mock(spec=ConfigurableTask)
task.multiple_inputs = False
task.config = Mock()
task.config.doc_to_text = ["{{choice1}}", "{{choice2}}"]
doc = {"choice1": "1", "choice2": "2"}
result = ConfigurableTask.doc_to_text(task, doc)
assert result == ["1", "2"]
def test_doc_to_text_with_callable(self):
"""Test doc_to_text with a callable function."""
def custom_text_func(doc):
return f"Custom: {doc['text']}"
task = Mock(spec=ConfigurableTask)
task.multiple_inputs = False
task.config = Mock()
task.config.doc_to_text = custom_text_func
doc = {"text": "test"}
result = ConfigurableTask.doc_to_text(task, doc)
assert result == "Custom: test"
def test_doc_to_text_with_regex_filter(self):
"""Test doc_to_text with Jinja regex_replace filter."""
task = Mock(spec=ConfigurableTask)
task.multiple_inputs = False
task.features = ["text"]
task.config = Mock()
task.config.doc_to_text = "{{text|regex_replace('\\d+', 'X')}}"
doc = {"text": "There are 123 apples and 456 oranges"}
result = ConfigurableTask.doc_to_text(task, doc)
assert result == "There are X apples and X oranges"
def test_doc_to_text_with_list_comprehension(self):
"""Test doc_to_text with Jinja list comprehension."""
task = Mock(spec=ConfigurableTask)
task.multiple_inputs = False
task.features = []
task.config = Mock()
task.config.doc_to_text = "Options: {{ ', '.join(choices) }}"
doc = {"choices": ["red", "green", "blue"]}
result = ConfigurableTask.doc_to_text(task, doc)
assert result == "Options: red, green, blue"
def test_override_doc_to_text(self):
"""Test overriding doc_to_text with parameter."""
task = Mock(spec=ConfigurableTask)
task.multiple_inputs = False
task.features = []
task.config = Mock()
task.config.doc_to_text = "default"
doc = {"text": "test"}
result = ConfigurableTask.doc_to_text(task, doc, doc_to_text="override")
assert result == "override"
def test_doc_to_text_type_error(self):
"""Test doc_to_text raises TypeError for invalid type."""
task = Mock(spec=ConfigurableTask)
task.multiple_inputs = False
task.config = Mock()
task.config.doc_to_text = {"invalid": "type"}
doc = {"text": "test"}
with pytest.raises(TypeError):
ConfigurableTask.doc_to_text(task, doc)
def test_doc_to_text_with_missing_field(self):
"""Test doc_to_text with missing field in template."""
task = Mock(spec=ConfigurableTask)
task.multiple_inputs = False
task.features = []
task.config = Mock()
task.config.doc_to_text = "{{missing_field}}"
doc = {"text": "test"}
from jinja2 import UndefinedError
with pytest.raises(UndefinedError):
ConfigurableTask.doc_to_text(task, doc)
class TestDocToTargetMethod:
"""Test suite for doc_to_target method."""
def test_doc_to_target_with_field(self):
"""Test doc_to_target when config points to a field name."""
task = Mock(spec=ConfigurableTask)
task.features = ["text", "answer"]
task.config = Mock()
task.config.doc_to_target = "answer"
task._config = task.config
doc = {"text": "question", "answer": "correct answer"}
result = ConfigurableTask.doc_to_target(task, doc)
assert result == "correct answer"
def test_doc_to_target_with_jinja_template(self):
"""Test doc_to_target with Jinja template."""
task = Mock(spec=ConfigurableTask)
task.features = []
task.config = Mock()
task.config.doc_to_target = "{{answer}}"
task.config.doc_to_choice = None
task._config = task.config
doc = {"answer": "test_answer"}
result = ConfigurableTask.doc_to_target(task, doc)
assert result == "test_answer"
def test_doc_to_target_with_jinja_index(self):
"""Test doc_to_target with Jinja template returning numeric string."""
task = Mock(spec=ConfigurableTask)
task.features = []
task.config = Mock()
task.config.doc_to_target = "{{label}}"
task.config.doc_to_choice = ["A", "B", "C"]
task._config = task.config
doc = {"label": "1"}
result = ConfigurableTask.doc_to_target(task, doc)
assert result == 1 # Should be converted to int
def test_doc_to_target_with_int(self):
"""Test doc_to_target when config is an integer."""
task = Mock(spec=ConfigurableTask)
task.config = Mock()
task.config.doc_to_target = 0
task._config = task.config
doc = {"answer": "test"}
result = ConfigurableTask.doc_to_target(task, doc)
assert result == 0
def test_doc_to_target_with_list(self):
"""Test doc_to_target with list of templates."""
task = Mock(spec=ConfigurableTask)
task.features = []
task.config = Mock()
task.config.doc_to_target = ["{{answer}}", "{{text}}"]
task._config = task.config
doc = {"answer": "A", "text": "question"}
result = ConfigurableTask.doc_to_target(task, doc)
assert result == ["A", "question"]
def test_doc_to_target_with_int_list(self):
"""Test doc_to_target with list of templates."""
task = Mock(spec=ConfigurableTask)
task.features = []
task.multiple_targets = True
task.config = Mock()
task.config.doc_to_target = "{{answer}}"
task._config = task.config
doc = {"answer": [1, 2, 3, 4]}
result = ConfigurableTask.doc_to_target(task, doc)
assert result == [1, 2, 3, 4]
def test_doc_to_target_with_callable(self):
"""Test doc_to_target with a callable function."""
def custom_target_func(doc):
return doc["label"] * 2
task = Mock(spec=ConfigurableTask)
task.config = Mock()
task.config.doc_to_target = custom_target_func
task._config = task.config
doc = {"label": 3}
result = ConfigurableTask.doc_to_target(task, doc)
assert result == 6
def test_doc_to_target_with_nested_fields(self):
"""Test doc_to_target with nested field access."""
task = Mock(spec=ConfigurableTask)
task.features = []
task.config = Mock()
task.config.doc_to_target = "{{meta.answer}}"
task.config.doc_to_choice = None
task._config = task.config
doc = {"meta": {"answer": "nested_value"}}
result = ConfigurableTask.doc_to_target(task, doc)
assert result == "nested_value"
def test_doc_to_target_multiple_targets(self):
"""Test doc_to_target returning list for multiple targets."""
task = Mock(spec=ConfigurableTask)
task.features = []
task.config = Mock()
task.config.doc_to_target = ["{{answer1}}", "{{answer2}}"]
task._config = task.config
doc = {"answer1": "first", "answer2": "second"}
result = ConfigurableTask.doc_to_target(task, doc)
assert result == ["first", "second"]
def test_override_doc_to_target(self):
"""Test overriding doc_to_target with parameter."""
task = Mock(spec=ConfigurableTask)
task.features = []
task.config = Mock()
task.config.doc_to_target = "default"
task._config = task.config
doc = {"answer": "test"}
result = ConfigurableTask.doc_to_target(task, doc, doc_to_target="override")
assert result == "override"
def test_doc_to_target_type_error(self):
"""Test doc_to_target raises TypeError for invalid type."""
task = Mock(spec=ConfigurableTask)
task.config = Mock()
task.config.doc_to_target = {"invalid": "type"}
task._config = task.config
doc = {"answer": "test"}
with pytest.raises(TypeError):
ConfigurableTask.doc_to_target(task, doc)
def test_doc_to_target_literal_eval_edge_cases(self):
"""Test doc_to_target with edge cases for literal_eval."""
task = Mock(spec=ConfigurableTask)
task.features = []
task.config = Mock()
task.config.doc_to_choice = ["A", "B", "C"]
task._config = task.config
# Test numeric string conversion
task.config.doc_to_target = "{{label}}"
doc = {"label": "2"}
result = ConfigurableTask.doc_to_target(task, doc)
assert result == 2
# Test non-numeric string stays as string
doc = {"label": "abc"}
result = ConfigurableTask.doc_to_target(task, doc)
assert result == "abc"
# Test mixed alphanumeric stays as string
doc = {"label": "2a"}
result = ConfigurableTask.doc_to_target(task, doc)
assert result == "2a"
class TestDocToChoiceMethod:
"""Test suite for doc_to_choice method."""
def test_doc_to_choice_with_field(self):
"""Test doc_to_choice when config points to a field name."""
task = Mock(spec=ConfigurableTask)
task.features = ["choices"]
task.config = Mock()
task.config.doc_to_choice = "choices"
doc = {"choices": ["A", "B", "C", "D"]}
result = ConfigurableTask.doc_to_choice(task, doc)
assert result == ["A", "B", "C", "D"]
def test_doc_to_choice_with_jinja_list(self):
"""Test doc_to_choice with Jinja template returning list as string."""
task = Mock(spec=ConfigurableTask)
task.features = []
task.config = Mock()
task.config.doc_to_choice = "{{choices}}"
doc = {"choices": ["opt1", "opt2", "opt3"]}
# The Jinja template will render the list as a string
result = ConfigurableTask.doc_to_choice(task, doc)
assert result == ["opt1", "opt2", "opt3"]
def test_doc_to_choice_with_jinja_list_literal(self):
"""Test doc_to_choice with Jinja template creating a list literal."""
task = Mock(spec=ConfigurableTask)
task.features = []
task.config = Mock()
task.config.doc_to_choice = "{{[correct, wrong]}}"
doc = {"correct": "The right answer", "wrong": "The wrong answer"}
# The Jinja template will create a list literal and render it as a string
result = ConfigurableTask.doc_to_choice(task, doc)
assert result == ["The right answer", "The wrong answer"]
# Test with another variation
task.config.doc_to_choice = "{{[option_a, option_b, option_c]}}"
doc = {"option_a": "Choice A", "option_b": "Choice B", "option_c": "Choice C"}
result = ConfigurableTask.doc_to_choice(task, doc)
assert result == ["Choice A", "Choice B", "Choice C"]
def test_doc_to_choice_with_list_of_templates(self):
"""Test doc_to_choice with list of Jinja templates."""
task = Mock(spec=ConfigurableTask)
task.features = []
task.config = Mock()
task.config.doc_to_choice = ["{{choice_a}}", "{{choice_b}}", "{{choice_c}}"]
doc = {"choice_a": "Apple", "choice_b": "Banana", "choice_c": "Cherry"}
result = ConfigurableTask.doc_to_choice(task, doc)
assert result == ["Apple", "Banana", "Cherry"]
def test_doc_to_choice_with_dict(self):
"""Test doc_to_choice with dictionary config."""
task = Mock(spec=ConfigurableTask)
task.config = Mock()
task.config.doc_to_choice = {
"A": "First option",
"B": "Second option",
"C": "Third option",
}
doc = {}
result = ConfigurableTask.doc_to_choice(task, doc)
assert result == ["First option", "Second option", "Third option"]
def test_doc_to_choice_with_callable(self):
"""Test doc_to_choice with a callable function."""
def custom_choice_func(doc):
return [f"Option {i}" for i in range(doc["num_choices"])]
task = Mock(spec=ConfigurableTask)
task.config = Mock()
task.config.doc_to_choice = custom_choice_func
doc = {"num_choices": 3}
result = ConfigurableTask.doc_to_choice(task, doc)
assert result == ["Option 0", "Option 1", "Option 2"]
def test_doc_to_choice_none_error(self):
"""Test doc_to_choice logs error when not configured."""
task = Mock(spec=ConfigurableTask)
task.config = Mock()
task.config.doc_to_choice = None
doc = {}
# When doc_to_choice is None, it logs an error and then raises TypeError
with patch("lm_eval.api.task.eval_logger.error") as mock_error:
with pytest.raises(TypeError):
ConfigurableTask.doc_to_choice(task, doc)
mock_error.assert_called_once_with(
"doc_to_choice was called but not set in config"
)
def test_doc_to_choice_with_conditional(self):
"""Test doc_to_choice with Jinja conditional."""
task = Mock(spec=ConfigurableTask)
task.features = []
task.config = Mock()
task.config.doc_to_choice = "{{choices if has_choices else default_choices}}"
doc = {
"has_choices": True,
"choices": ["A", "B"],
"default_choices": ["X", "Y"],
}
result = ConfigurableTask.doc_to_choice(task, doc)
assert result == ["A", "B"]
def test_override_doc_to_choice(self):
"""Test overriding doc_to_choice with parameter."""
task = Mock(spec=ConfigurableTask)
task.config = Mock()
task.config.doc_to_choice = ["A", "B"]
doc = {}
result = ConfigurableTask.doc_to_choice(
task, doc, doc_to_choice=["X", "Y", "Z"]
)
assert result == ["X", "Y", "Z"]
def test_doc_to_choice_type_error(self):
"""Test doc_to_choice raises TypeError for invalid type."""
task = Mock(spec=ConfigurableTask)
task.config = Mock()
task.config.doc_to_choice = 123 # Invalid type
doc = {}
with pytest.raises(TypeError):
ConfigurableTask.doc_to_choice(task, doc)
@@ -29,7 +29,7 @@ def get_new_tasks_else_default():
     return task_classes if task_classes else TASKS
-def task_class(task_names=None, task_manager=None) -> ConfigurableTask:
+def task_class(task_names=None, task_manager=None) -> list[ConfigurableTask]:
     """
     Convert a list of task names to a list of ConfigurableTask instances
     """
@@ -98,7 +98,7 @@ class TestBaseTasks:
         _array = [task.doc_to_text(doc) for doc in arr]
         # space convention; allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
         target_delimiter: str = task.config.target_delimiter
-        if not task.multiple_input:
+        if not task.multiple_inputs:
             for x in _array:
                 assert isinstance(x, str)
                 assert (
@@ -133,10 +133,7 @@ class TestBaseTasks:
         _array_target = [task.doc_to_target(doc) for doc in arr]
         if task._config.output_type == "multiple_choice":
             # TODO<baber>: label can be string or int; add better test conditions
-            assert all(
-                (isinstance(label, int) or isinstance(label, str))
-                for label in _array_target
-            )
+            assert all(isinstance(label, (int, str)) for label in _array_target)

     def test_build_all_requests(self, task_class, limit):
         task_class.build_all_requests(rank=1, limit=limit, world_size=1)
@@ -153,7 +150,7 @@ class TestBaseTasks:
         # ctx is "" for multiple input tasks
         requests = [
             task.construct_requests(
-                doc=doc, ctx="" if task.multiple_input else task.doc_to_text(doc)
+                doc=doc, ctx=[""] if task.multiple_inputs else task.doc_to_text(doc)
             )
             for doc in arr
         ]
@@ -187,7 +184,7 @@ class TestUnitxtTasks(TestBaseTasks):
     """
     def test_check_training_docs(self, task_class: ConfigurableTask):
-        if task_class.has_training_docs():
+        if task_class.has_training_docs:
             assert task_class.dataset["train"] is not None
     def test_check_validation_docs(self, task_class):
...