import json
import logging
import warnings
from argparse import Namespace
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import yaml

from lm_eval.utils import simple_parse_args_string


if TYPE_CHECKING:
    from lm_eval.tasks import TaskManager


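# Keys whose values may be passed as "key=value,key2=value2" strings;
# EvaluatorConfig._parse_dict_args converts them to dicts via simple_parse_args_string.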
DICT_KEYS = [
    "wandb_args",
    "wandb_config_args",
    "hf_hub_log_args",
    "metadata",
    "model_args",
]


@dataclass
class EvaluatorConfig:
    """
    Configuration container for initializing the evaluator or simple_evaluate.

    This dataclass holds all the parameters needed for running evaluations,
    with sensible defaults and documentation for each field.
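
    Example (illustrative usage; the model and task names below are placeholders):

        cfg = EvaluatorConfig(
            model="hf",
            model_args={"pretrained": "gpt2"},
            tasks="hellaswag",
            batch_size=8,
        )
        cfg.validate_and_preprocess()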
    """

    # Core evaluation parameters
    config: Optional[str] = field(
        default=None, metadata={"help": "Path to YAML config file"}
    )
    model: str = field(default="hf", metadata={"help": "Name of model e.g. 'hf'"})
    model_args: dict = field(
        default_factory=dict, metadata={"help": "Arguments for model initialization"}
    )
    tasks: Union[str, list[str]] = field(
        default_factory=list,
        metadata={"help": "Comma-separated list of task names to evaluate"},
    )

    # Few-shot and batching
    num_fewshot: Optional[int] = field(
        default=None, metadata={"help": "Number of examples in few-shot context"}
    )
    batch_size: int = field(default=1, metadata={"help": "Batch size for evaluation"})
    max_batch_size: Optional[int] = field(
        default=None, metadata={"help": "Maximum batch size for auto batching"}
    )

    # Device
    device: Optional[str] = field(
        default=None, metadata={"help": "Device to use (e.g. cuda, cuda:0, cpu)"}
    )

    # Data sampling and limiting
    limit: Optional[float] = field(
        default=None, metadata={"help": "Limit number of examples per task"}
    )
    samples: Union[str, dict, None] = field(
        default=None,
        metadata={"help": "dict, JSON string or path to JSON file with doc indices"},
    )

    # Caching
    use_cache: Optional[str] = field(
        default=None,
        metadata={"help": "Path to sqlite db file for caching model outputs"},
    )
    cache_requests: dict = field(
        default_factory=dict,
        metadata={"help": "Cache dataset requests: true/refresh/delete"},
    )

    # Output and logging flags
    check_integrity: bool = field(
        default=False, metadata={"help": "Run test suite for tasks"}
    )
    write_out: bool = field(
        default=False, metadata={"help": "Print prompts for first few documents"}
    )
    log_samples: bool = field(
        default=False, metadata={"help": "Save model outputs and inputs"}
    )
    output_path: Optional[str] = field(
        default=None, metadata={"help": "Dir path where result metrics will be saved"}
    )
    predict_only: bool = field(
        default=False,
        metadata={
            "help": "Only save model outputs, don't evaluate metrics. Use with log_samples."
        },
    )

    # Chat and instruction handling
    system_instruction: Optional[str] = field(
        default=None, metadata={"help": "Custom System instruction to add"}
    )
    apply_chat_template: Union[bool, str] = field(
        default=False, metadata={"help": "Apply chat template to prompt"}
    )
    fewshot_as_multiturn: bool = field(
        default=False,
        metadata={
            "help": "Use fewshot as multi-turn conversation. Requires apply_chat_template=True."
        },
    )

    # Configuration display
    show_config: bool = field(
        default=False, metadata={"help": "Show full config at end of evaluation"}
    )

    # External tasks and generation
    include_path: Optional[str] = field(
        default=None, metadata={"help": "Additional dir path for external tasks"}
    )
    gen_kwargs: Optional[dict] = field(
        default=None, metadata={"help": "Arguments for model generation"}
    )

    # Logging and verbosity
    verbosity: Optional[str] = field(
        default=None, metadata={"help": "Logging verbosity level"}
    )

    # External integrations
    wandb_args: dict = field(
        default_factory=dict, metadata={"help": "Arguments for wandb.init"}
    )
    wandb_config_args: dict = field(
        default_factory=dict, metadata={"help": "Arguments for wandb.config.update"}
    )
    hf_hub_log_args: dict = field(
        default_factory=dict, metadata={"help": "Arguments for HF Hub logging"}
    )

    # Reproducibility
    seed: list = field(
        default_factory=lambda: [0, 1234, 1234, 1234],
        metadata={"help": "Seeds for random, numpy, torch, fewshot (random)"},
    )

    # Security and safety
    trust_remote_code: bool = field(
        default=False, metadata={"help": "Trust remote code for HF datasets"}
    )
    confirm_run_unsafe_code: bool = field(
        default=False,
        metadata={
            "help": "Confirm understanding of unsafe code risks (for code tasks that executes arbitrary Python)"
        },
    )

    # Internal metadata
    metadata: dict = field(
        default_factory=dict,
        metadata={"help": "Additional metadata for tasks that require it"},
    )

    @staticmethod
    def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
        """Parse string arguments that should be dictionaries."""
        for key in config:
            if key in DICT_KEYS and isinstance(config[key], str):
                config[key] = simple_parse_args_string(config[key])
        return config

    @classmethod
    def from_cli(cls, namespace: Namespace) -> "EvaluatorConfig":
        """
        Build an EvaluatorConfig by merging with simple precedence:
        CLI args > YAML config > built-in defaults
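
        Illustrative example: if the YAML config sets `batch_size: 8` and the
        CLI passes `--batch_size 16`, the resulting config uses 16 (CLI wins).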
        """
        # Start with built-in defaults
        config = asdict(cls())

        # Load and merge YAML config if provided
        if hasattr(namespace, "config") and namespace.config:
            config.update(cls._load_yaml_config(namespace.config))

        # Override with CLI args (only truthy values, exclude non-config args)
        excluded_args = {"config", "command", "func"}  # argparse internal args
        cli_args = {
            k: v for k, v in vars(namespace).items() if v and k not in excluded_args
        }
        config.update(cli_args)

        # Parse string arguments that should be dictionaries
        config = cls._parse_dict_args(config)

        # Create instance and validate
        instance = cls(**config)
        instance.validate_and_preprocess()

        return instance

    @classmethod
    def from_config(cls, config_path: Union[str, Path]) -> "EvaluatorConfig":
        """
        Build an EvaluatorConfig from a YAML config file.
        Merges with built-in defaults and validates.
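
        Illustrative example (keys mirror the dataclass fields):

            model: hf
            model_args: pretrained=gpt2,dtype=float16
            tasks: hellaswag,arc_easy
            batch_size: 8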
        """
        # Load YAML config
        yaml_config = cls._load_yaml_config(config_path)

        # Parse string arguments that should be dictionaries
        yaml_config = cls._parse_dict_args(yaml_config)

        # Create instance and validate
        instance = cls(**yaml_config)
        instance.validate_and_preprocess()

        return instance

    @staticmethod
    def _load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
        """Load and validate YAML config file."""
        config_file = (
            Path(config_path) if not isinstance(config_path, Path) else config_path
        )
        if not config_file.is_file():
            raise FileNotFoundError(f"Config file not found: {config_path}")

        try:
            yaml_data = yaml.safe_load(config_file.read_text())
        except yaml.YAMLError as e:
            raise ValueError(f"Invalid YAML in {config_path}: {e}")
        except (OSError, UnicodeDecodeError) as e:
            raise ValueError(f"Could not read config file {config_path}: {e}")

        if not isinstance(yaml_data, dict):
            raise ValueError(
                f"YAML root must be a mapping, got {type(yaml_data).__name__}"
            )

        return yaml_data

    def validate_and_preprocess(self) -> None:
        """Validate configuration and preprocess fields after creation."""
        self._validate_arguments()
        self._process_samples()
        self._setup_metadata()
        self._apply_trust_remote_code()

    def _validate_arguments(self) -> None:
        """Validate configuration arguments and cross-field constraints."""
        if self.limit:
            warnings.warn(
                "--limit SHOULD ONLY BE USED FOR TESTING. "
                "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
            )

        # predict_only implies log_samples
        if self.predict_only:
            self.log_samples = True

        # log_samples or predict_only requires output_path
        if (self.log_samples or self.predict_only) and not self.output_path:
            raise ValueError(
                "Specify --output_path if providing --log_samples or --predict_only"
            )

        # fewshot_as_multiturn requires apply_chat_template
        if self.fewshot_as_multiturn and self.apply_chat_template is False:
            raise ValueError(
                "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set."
            )

        # samples and limit are mutually exclusive
        if self.samples and self.limit is not None:
            raise ValueError("If --samples is not None, then --limit must be None.")

        # tasks is required
        if not self.tasks:
            raise ValueError("Need to specify at least one task to evaluate.")

    def _process_samples(self) -> None:
        """Process the samples argument: parse inline JSON or load it from a file."""
        # A dict is already in its final form; only strings need further handling.
        if self.samples and isinstance(self.samples, str):
            try:
                # First try to parse the string as inline JSON ...
                self.samples = json.loads(self.samples)
            except json.JSONDecodeError:
                # ... otherwise treat it as a path to a JSON file.
                if (samples_path := Path(self.samples)).is_file():
                    self.samples = json.loads(samples_path.read_text())

    def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager":
        """Process and validate tasks, return resolved task names."""
        from lm_eval import utils
        from lm_eval.tasks import TaskManager

        # If metadata is passed explicitly, prefer it over the configured value:
        self.metadata = metadata if metadata else self.metadata

        # Create task manager with metadata
        task_manager = TaskManager(
            include_path=self.include_path,
            metadata=self.metadata if self.metadata else {},
        )

        # self.tasks is a comma-separated string of task names
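        # (illustrative: "hellaswag,arc_easy" -> ["hellaswag", "arc_easy"]; names
        # containing "*" are matched as wildcard patterns by match_tasks)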
        if isinstance((task_list := self.tasks), str):
            task_list = self.tasks.split(",")
        else:
            assert isinstance(self.tasks, list), (
                "Tasks must be a comma delimited string of task names or list[str]."
            )
        task_names = task_manager.match_tasks(task_list)

        # Check for any individual task files in the list
        for task in [task for task in task_list if task not in task_names]:
            task_path = Path(task)
            if task_path.is_file():
                config = utils.load_yaml_config(str(task_path))
                task_names.append(config)

        # Check for missing tasks
        task_missing = [
            task for task in task_list if task not in task_names and "*" not in task
        ]

        if task_missing:
            missing = ", ".join(task_missing)
            raise ValueError(f"Tasks not found: {missing}")

        # Update tasks with resolved names
        self.tasks = task_names
        return task_manager

    def _setup_metadata(self) -> None:
        """Set up metadata by merging model_args and metadata."""
        if self.model_args is None:
            self.model_args = {}
        if self.metadata is None:
            self.metadata = {}

        self.metadata = self.model_args | self.metadata

    def _apply_trust_remote_code(self) -> None:
        """Apply trust_remote_code setting if enabled."""
        if self.trust_remote_code:
            eval_logger = logging.getLogger(__name__)
            eval_logger.info("Setting HF_DATASETS_TRUST_REMOTE_CODE=true")

            # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
            # because it's already been determined based on the prior env var before launching our
            # script--`datasets` gets imported by lm_eval internally before these lines can update the env.
            import datasets

            datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

            # Add to model_args for the actual model initialization
            if self.model_args is None:
                self.model_args = {}
            self.model_args["trust_remote_code"] = True