import json
import logging
from argparse import Namespace
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import yaml

from lm_eval.utils import simple_parse_args_string


if TYPE_CHECKING:
    from lm_eval.tasks import TaskManager

eval_logger = logging.getLogger(__name__)
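
# Keys whose values may arrive as "key=value,key2=value2" strings (e.g. from the
# CLI) and are normalized into dicts by `EvaluatorConfig._parse_dict_args` below.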
DICT_KEYS = [
    "wandb_args",
    "wandb_config_args",
    "hf_hub_log_args",
    "metadata",
    "model_args",
    "gen_kwargs",
]


@dataclass
class EvaluatorConfig:
    """Configuration for language model evaluation runs.

    This dataclass contains all parameters for configuring model evaluations via
    `simple_evaluate()` or the CLI. It supports initialization from:
    - CLI arguments (via `from_cli()`)
    - YAML configuration files (via `from_config()`)
    - Direct instantiation with keyword arguments

    The configuration handles argument parsing, validation, and preprocessing
    so that all fields are properly structured and validated before evaluation.

    Example:
        # From CLI arguments
        config = EvaluatorConfig.from_cli(args)

        # From YAML file
        config = EvaluatorConfig.from_config("eval_config.yaml")

        # Direct instantiation
        config = EvaluatorConfig(
            model="hf",
            model_args={"pretrained": "gpt2"},
            tasks=["hellaswag", "arc_easy"],
            num_fewshot=5
        )

    See individual field documentation for detailed parameter descriptions.
    """

    # Core evaluation parameters
    config: Optional[str] = field(
        default=None, metadata={"help": "Path to YAML config file"}
    )
    model: str = field(default="hf", metadata={"help": "Name of model e.g. 'hf'"})
    model_args: dict = field(
        default_factory=dict, metadata={"help": "Arguments for model initialization"}
    )
    tasks: Union[str, list[str]] = field(
        default_factory=list,
        metadata={"help": "Comma-separated list of task names to evaluate"},
    )

    # Few-shot and batching
    num_fewshot: Optional[int] = field(
        default=None, metadata={"help": "Number of examples in few-shot context"}
    )
    batch_size: int = field(default=1, metadata={"help": "Batch size for evaluation"})
    max_batch_size: Optional[int] = field(
        default=None, metadata={"help": "Maximum batch size for auto batching"}
    )

    # Device
    device: Optional[str] = field(
        default="cuda:0", metadata={"help": "Device to use (e.g. cuda, cuda:0, cpu)"}
    )

    # Data sampling and limiting
    limit: Optional[float] = field(
        default=None, metadata={"help": "Limit number of examples per task"}
    )
    samples: Union[str, dict, None] = field(
        default=None,
        metadata={"help": "dict, JSON string or path to JSON file with doc indices"},
    )
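    # Illustrative `samples` shape (assumed): {"hellaswag": [0, 1, 2]}, i.e. a mapping
    # of task name to the doc indices to evaluate.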

    # Caching
    use_cache: Optional[str] = field(
        default=None,
        metadata={"help": "Path to sqlite db file for caching model outputs"},
    )
    cache_requests: dict = field(
        default_factory=dict,
        metadata={"help": "Cache dataset requests: true/refresh/delete"},
    )

    # Output and logging flags
    check_integrity: bool = field(
        default=False, metadata={"help": "Run test suite for tasks"}
    )
    write_out: bool = field(
        default=False, metadata={"help": "Print prompts for first few documents"}
    )
    log_samples: bool = field(
        default=False, metadata={"help": "Save model outputs and inputs"}
    )
    output_path: Optional[str] = field(
        default=None, metadata={"help": "Dir path where result metrics will be saved"}
    )
    predict_only: bool = field(
        default=False,
        metadata={
            "help": "Only save model outputs, don't evaluate metrics. Use with log_samples."
        },
    )

    # Chat and instruction handling
    system_instruction: Optional[str] = field(
        default=None, metadata={"help": "Custom System instruction to add"}
    )
    apply_chat_template: Union[bool, str] = field(
        default=False,
        metadata={
            "help": "Apply chat template to prompt. Either True, or a string identifying the tokenizer template."
        },
    )
    fewshot_as_multiturn: bool = field(
        default=False,
        metadata={
            "help": "Use fewshot as multi-turn conversation. Requires apply_chat_template=True."
        },
    )

    # Configuration display
    show_config: bool = field(
        default=False, metadata={"help": "Show full config at end of evaluation"}
    )

    # External tasks and generation
    include_path: Optional[str] = field(
        default=None, metadata={"help": "Additional dir path for external tasks"}
    )
    gen_kwargs: Optional[dict] = field(
        default=None, metadata={"help": "Arguments for model generation"}
    )
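    # Illustrative: gen_kwargs={"temperature": 0, "do_sample": False}, or the CLI
    # string "temperature=0,do_sample=False", which `_parse_dict_args` turns into a dict.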

    # Logging and verbosity
    verbosity: Optional[str] = field(
        default=None, metadata={"help": "Logging verbosity level"}
    )

    # External integrations
    wandb_args: dict = field(
        default_factory=dict, metadata={"help": "Arguments for wandb.init"}
    )
    wandb_config_args: dict = field(
        default_factory=dict, metadata={"help": "Arguments for wandb.config.update"}
    )
    hf_hub_log_args: dict = field(
        default_factory=dict, metadata={"help": "Arguments for HF Hub logging"}
    )

    # Reproducibility
    seed: list = field(
        default_factory=lambda: [0, 1234, 1234, 1234],
        metadata={"help": "Seeds for random, numpy, torch, fewshot (random)"},
    )

    # Security
    trust_remote_code: bool = field(
        default=False, metadata={"help": "Trust remote code for HF datasets"}
    )
    confirm_run_unsafe_code: bool = field(
        default=False,
        metadata={
            "help": "Confirm understanding of unsafe code risks (for code tasks that executes arbitrary Python)"
        },
    )

    # Internal metadata
    metadata: dict = field(
        default_factory=dict,
        metadata={"help": "Additional metadata for tasks that require it"},
    )

    @classmethod
    def from_cli(cls, namespace: Namespace) -> "EvaluatorConfig":
        """
        Build an EvaluatorConfig by merging with simple precedence:
        CLI args > YAML config > built-in defaults
        """
        # Start with built-in defaults
        config = asdict(cls())

        # Load and merge YAML config if provided
        used_config = bool(getattr(namespace, "config", None))
        if used_config:
            config.update(cls.load_yaml_config(namespace.config))

        # Override with CLI args (only truthy values, exclude non-config args)
        excluded_args = {"command", "func"}  # argparse internal args
        cli_args = {
            k: v for k, v in vars(namespace).items() if v and k not in excluded_args
        }
        config.update(cli_args)

        # Parse string arguments that should be dictionaries
        config = cls._parse_dict_args(config)

        # Create instance and validate
        instance = cls(**config)
        if used_config:
            print(instance)
        instance.configure()

        return instance

    @classmethod
    def from_config(cls, config_path: Union[str, Path]) -> "EvaluatorConfig":
        """
        Build an EvaluatorConfig from a YAML config file.
        Merges with built-in defaults and validates.
        """
        # Load YAML config
        yaml_config = cls.load_yaml_config(config_path)
        # Parse string arguments that should be dictionaries
        yaml_config = cls._parse_dict_args(yaml_config)
        instance = cls(**yaml_config)
        instance.configure()

        return instance

    @staticmethod
    def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
        """Parse string arguments that should be dictionaries."""
        for key in config:
            if key in DICT_KEYS and isinstance(config[key], str):
                config[key] = simple_parse_args_string(config[key])
        return config

    @staticmethod
    def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
        """Load and validate YAML config file."""
        config_file = Path(config_path)
        if not config_file.is_file():
            raise FileNotFoundError(f"Config file not found: {config_path}")

        try:
            yaml_data = yaml.safe_load(config_file.read_text())
        except yaml.YAMLError as e:
            raise ValueError(f"Invalid YAML in {config_path}: {e}") from e
        except (OSError, UnicodeDecodeError) as e:
            raise ValueError(f"Could not read config file {config_path}: {e}") from e

        if not isinstance(yaml_data, dict):
            raise ValueError(
                f"YAML root must be a mapping, got {type(yaml_data).__name__}"
            )

        return yaml_data

    def configure(self) -> None:
        """Validate configuration and preprocess fields after creation."""
        self._validate_arguments()
        self._process_arguments()
        self._set_trust_remote_code()

    def _validate_arguments(self) -> None:
        """Validate configuration arguments and cross-field constraints."""
        if self.limit:
            eval_logger.warning(
                "--limit SHOULD ONLY BE USED FOR TESTING. "
                "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
            )

        # predict_only implies log_samples
        if self.predict_only:
            self.log_samples = True

        # log_samples or predict_only requires output_path
        if (self.log_samples or self.predict_only) and not self.output_path:
            raise ValueError(
                "Specify --output_path if providing --log_samples or --predict_only"
            )

        # fewshot_as_multiturn requires apply_chat_template
        if self.fewshot_as_multiturn and self.apply_chat_template is False:
            raise ValueError(
                "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set."
            )

        # samples and limit are mutually exclusive
        if self.samples and self.limit is not None:
            raise ValueError("If --samples is not None, then --limit must be None.")

        # tasks is required
        if not self.tasks:
            raise ValueError("Need to specify a task to evaluate.")

    def _process_arguments(self) -> None:
        """Resolve the samples argument and merge model_args into metadata."""
        if self.samples and isinstance(self.samples, str):
            # Accept an inline JSON string or a path to a JSON file; dicts pass through as-is.
            try:
                self.samples = json.loads(self.samples)
            except json.JSONDecodeError:
                if (samples_path := Path(self.samples)).is_file():
                    self.samples = json.loads(samples_path.read_text())
                else:
                    raise ValueError(
                        f"--samples must be valid JSON or a path to a JSON file, got: {self.samples}"
                    )

        # Set up metadata by merging model_args and metadata.
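        # Dict union: on duplicate keys, values from `self.metadata` take precedence over `model_args`.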
        if self.model_args is None:
            self.model_args = {}
        if self.metadata is None:
            self.metadata = {}

        self.metadata = self.model_args | self.metadata

    def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager":
        """Process and validate tasks, return resolved task names."""
        from lm_eval import utils
        from lm_eval.tasks import TaskManager

        # If metadata is passed explicitly, prefer it over the configured metadata.
        self.metadata = metadata if metadata else self.metadata

        # Create task manager with metadata
        task_manager = TaskManager(
            include_path=self.include_path,
            metadata=self.metadata if self.metadata else {},
        )

        task_names = self.tasks
        # TODO: FIX TASKS VALIDATION!!!
        # task_names = task_manager.match_tasks(self.tasks)

        # # Check for any individual task files in the list
        # for task in [task for task in self.tasks if task not in task_names]:
        #     task_path = Path(task)
        #     if task_path.is_file():
        #         config = utils.load_yaml_config(str(task_path))
        #         task_names.append(config)
        #
        # # Check for missing tasks
        # task_missing = [
        #     task for task in self.tasks if task not in task_names and "*" not in task
        # ]
        #
        # if task_missing:
        #     missing = ", ".join(task_missing)
        #     raise ValueError(f"Tasks not found: {missing}")

        # Update tasks with resolved names
        self.tasks = task_names
        return task_manager

    def _set_trust_remote_code(self) -> None:
        """Apply trust_remote_code setting if enabled."""
        if self.trust_remote_code:
            # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
            # because it's already been determined based on the prior env var before launching our
            # script--`datasets` gets imported by lm_eval internally before these lines can update the env.
            import datasets

            datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

            # Add to model_args for the actual model initialization
            if self.model_args is None:
                self.model_args = {}
            self.model_args["trust_remote_code"] = True