import json
import logging
from argparse import Namespace
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import yaml

from lm_eval.utils import simple_parse_args_string


if TYPE_CHECKING:
    from lm_eval.tasks import TaskManager

eval_logger = logging.getLogger(__name__)
DICT_KEYS = [
    "wandb_args",
    "wandb_config_args",
    "hf_hub_log_args",
    "metadata",
    "model_args",
]


@dataclass
class EvaluatorConfig:
    """Configuration for language model evaluation runs.

    This dataclass contains all parameters for configuring model evaluations via
    `simple_evaluate()` or the CLI. It supports initialization from:
    - CLI arguments (via `from_cli()`)
    - YAML configuration files (via `from_config()`)
    - Direct instantiation with keyword arguments

    The class handles argument parsing, validation, and preprocessing so that
    the resulting configuration is properly structured and validated.

    Example:
        # From CLI arguments
        config = EvaluatorConfig.from_cli(args)

        # From YAML file
        config = EvaluatorConfig.from_config("eval_config.yaml")

        # Direct instantiation
        config = EvaluatorConfig(
            model="hf",
            model_args={"pretrained": "gpt2"},
            tasks=["hellaswag", "arc_easy"],
            num_fewshot=5
        )

    See individual field documentation for detailed parameter descriptions.
    """

    # Core evaluation parameters
    config: Optional[str] = field(
        default=None, metadata={"help": "Path to YAML config file"}
    )
    model: str = field(default="hf", metadata={"help": "Name of model e.g. 'hf'"})
    model_args: dict = field(
        default_factory=dict, metadata={"help": "Arguments for model initialization"}
    )
    tasks: list[str] = field(
        default_factory=list,
        metadata={"help": "List of task names to evaluate"},
    )

    # Few-shot and batching
    num_fewshot: Optional[int] = field(
        default=None, metadata={"help": "Number of examples in few-shot context"}
    )
    batch_size: int = field(default=1, metadata={"help": "Batch size for evaluation"})
    max_batch_size: Optional[int] = field(
        default=None, metadata={"help": "Maximum batch size for auto batching"}
    )

    # Device
    device: Optional[str] = field(
        default=None, metadata={"help": "Device to use (e.g. cuda, cuda:0, cpu)"}
    )

    # Data sampling and limiting
    limit: Optional[float] = field(
        default=None, metadata={"help": "Limit number of examples per task"}
    )
    samples: Union[str, dict, None] = field(
        default=None,
        metadata={"help": "dict, JSON string or path to JSON file with doc indices"},
    )

    # Caching
    use_cache: Optional[str] = field(
        default=None,
        metadata={"help": "Path to sqlite db file for caching model outputs"},
    )
    cache_requests: dict = field(
        default_factory=dict,
        metadata={"help": "Cache dataset requests: true/refresh/delete"},
    )

    # Output and logging flags
    check_integrity: bool = field(
        default=False, metadata={"help": "Run test suite for tasks"}
    )
    write_out: bool = field(
        default=False, metadata={"help": "Print prompts for first few documents"}
    )
    log_samples: bool = field(
        default=False, metadata={"help": "Save model outputs and inputs"}
    )
    output_path: Optional[str] = field(
        default=None, metadata={"help": "Dir path where result metrics will be saved"}
    )
    predict_only: bool = field(
        default=False,
        metadata={
            "help": "Only save model outputs, don't evaluate metrics. Use with log_samples."
        },
    )

    # Chat and instruction handling
    system_instruction: Optional[str] = field(
        default=None, metadata={"help": "Custom System instruction to add"}
    )
    apply_chat_template: Union[bool, str] = field(
        default=False, metadata={"help": "Apply chat template to prompt"}
    )
    fewshot_as_multiturn: bool = field(
        default=False,
        metadata={
            "help": "Use fewshot as multi-turn conversation. Requires apply_chat_template=True."
        },
    )

    # Configuration display
    show_config: bool = field(
        default=False, metadata={"help": "Show full config at end of evaluation"}
    )

    # External tasks and generation
    include_path: Optional[str] = field(
        default=None, metadata={"help": "Additional dir path for external tasks"}
    )
    gen_kwargs: Optional[dict] = field(
        default=None, metadata={"help": "Arguments for model generation"}
    )

    # Logging and verbosity
    verbosity: Optional[str] = field(
        default=None, metadata={"help": "Logging verbosity level"}
    )

    # External integrations
    wandb_args: dict = field(
        default_factory=dict, metadata={"help": "Arguments for wandb.init"}
    )
    wandb_config_args: dict = field(
        default_factory=dict, metadata={"help": "Arguments for wandb.config.update"}
    )
    hf_hub_log_args: dict = field(
        default_factory=dict, metadata={"help": "Arguments for HF Hub logging"}
    )

    # Reproducibility
    seed: list = field(
        default_factory=lambda: [0, 1234, 1234, 1234],
        metadata={"help": "Seeds for random, numpy, torch, fewshot (random)"},
    )

    # Security and safety
    trust_remote_code: bool = field(
        default=False, metadata={"help": "Trust remote code for HF datasets"}
    )
    confirm_run_unsafe_code: bool = field(
        default=False,
        metadata={
            "help": "Confirm understanding of unsafe code risks (for code tasks that executes arbitrary Python)"
        },
    )

    # Internal metadata
    metadata: dict = field(
        default_factory=dict,
        metadata={"help": "Additional metadata for tasks that require it"},
    )

    @classmethod
    def from_cli(cls, namespace: Namespace) -> "EvaluatorConfig":
        """
        Build an EvaluatorConfig by merging sources with simple precedence:
        CLI args > YAML config > built-in defaults
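
        Example (a sketch; `args` is assumed to be a parsed argparse namespace
        from the harness CLI, with flag names mirroring the dataclass fields):

            # e.g. invoked with: --config eval_config.yaml --batch_size 16
            config = EvaluatorConfig.from_cli(args)
            # batch_size then comes from the CLI, other fields from the YAML,
            # and anything set in neither keeps its dataclass default.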
        """
        # Start with built-in defaults
        config = asdict(cls())

        # Load and merge YAML config if provided
        if used_config := getattr(namespace, "config", None):
            config.update(cls.load_yaml_config(namespace.config))

        # Override with CLI args (only truthy values, exclude non-config args)
        excluded_args = {"config", "command", "func"}  # argparse internal args
        cli_args = {
            k: v for k, v in vars(namespace).items() if v and k not in excluded_args
        }
        config.update(cli_args)

        # Parse string arguments that should be dictionaries
        config = cls._parse_dict_args(config)

        # Create instance and validate
        instance = cls(**config)
        instance.configure()
        if used_config:
            print(instance)

        return instance

    @classmethod
    def from_config(cls, config_path: Union[str, Path]) -> "EvaluatorConfig":
        """
        Build an EvaluatorConfig from a YAML config file.
        Merges with built-in defaults and validates.
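
        Example config (illustrative values; keys mirror the dataclass fields,
        and a string-valued `model_args` is parsed into a dict):

            # eval_config.yaml
            model: hf
            model_args: pretrained=gpt2,dtype=float16
            tasks:
              - hellaswag
              - arc_easy
            num_fewshot: 5
            batch_size: 8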
        """
        # Load YAML config
        yaml_config = cls.load_yaml_config(config_path)
        # Parse string arguments that should be dictionaries
        yaml_config = cls._parse_dict_args(yaml_config)
        instance = cls(**yaml_config)
        instance.configure()

        return instance

    @staticmethod
    def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
        """Parse string arguments that should be dictionaries."""
        for key in config:
            if key in DICT_KEYS and isinstance(config[key], str):
                config[key] = simple_parse_args_string(config[key])
        return config

    @staticmethod
    def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
        """Load and validate YAML config file."""
        config_file = (
            Path(config_path) if not isinstance(config_path, Path) else config_path
        )
        if not config_file.is_file():
            raise FileNotFoundError(f"Config file not found: {config_path}")

        try:
            yaml_data = yaml.safe_load(config_file.read_text())
            print(f"yaml: {yaml_data}")
        except yaml.YAMLError as e:
            raise ValueError(f"Invalid YAML in {config_path}: {e}") from e
        except (OSError, UnicodeDecodeError) as e:
            raise ValueError(f"Could not read config file {config_path}: {e}") from e

        if not isinstance(yaml_data, dict):
            raise ValueError(
                f"YAML root must be a mapping, got {type(yaml_data).__name__}"
            )

        return yaml_data

    def configure(self) -> None:
        """Validate configuration and preprocess fields after creation."""
        self._validate_arguments()
        self._process_arguments()
        self._set_trust_remote_code()

    def _validate_arguments(self) -> None:
        """Validate configuration arguments and cross-field constraints."""
        if self.limit:
            eval_logger.warning(
                "--limit SHOULD ONLY BE USED FOR TESTING. "
                "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
            )

        # predict_only implies log_samples
        if self.predict_only:
            self.log_samples = True

        # log_samples or predict_only requires output_path
        if (self.log_samples or self.predict_only) and not self.output_path:
            raise ValueError(
                "Specify --output_path if providing --log_samples or --predict_only"
            )

        # fewshot_as_multiturn requires apply_chat_template
        if self.fewshot_as_multiturn and self.apply_chat_template is False:
            raise ValueError(
                "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set."
            )

        # samples and limit are mutually exclusive
        if self.samples and self.limit is not None:
            raise ValueError("If --samples is not None, then --limit must be None.")

        # tasks is required
        if not self.tasks:
            raise ValueError("Need to specify task to evaluate.")

    def _process_arguments(self) -> None:
        """Process samples argument - load from file if needed."""
        if self.samples:
            if isinstance(self.samples, dict):
                pass  # already a dict; nothing to parse
            elif isinstance(self.samples, str):
                try:
                    self.samples = json.loads(self.samples)
                except json.JSONDecodeError:
                    if (samples_path := Path(self.samples)).is_file():
                        self.samples = json.loads(samples_path.read_text())

        # Set up metadata by merging model_args and metadata.
        if self.model_args is None:
            self.model_args = {}
        if self.metadata is None:
            self.metadata = {}

        self.metadata = self.model_args | self.metadata
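        # Note: with dict union the right-hand side wins, so explicit `metadata`
        # entries take precedence over duplicates coming from `model_args`.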

    def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager":
        """Process and validate tasks, return resolved task names."""
        from lm_eval import utils
        from lm_eval.tasks import TaskManager

        # if metadata is passed explicitly, use it instead of self.metadata
        self.metadata = metadata if metadata else self.metadata

        # Create task manager with metadata
        task_manager = TaskManager(
            include_path=self.include_path,
            metadata=self.metadata if self.metadata else {},
        )

        # self.tasks may be a comma-separated string of task names
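        # e.g. "hellaswag,arc_easy" -> ["hellaswag", "arc_easy"]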
        if isinstance((task_list := self.tasks), str):
            task_list = self.tasks.split(",")
        else:
            assert isinstance(self.tasks, list), (
                "`tasks` must be a comma delimited string of task names or list[str]."
            )
        task_names = task_manager.match_tasks(task_list)

        # Check for any individual task files in the list
        for task in [task for task in task_list if task not in task_names]:
            task_path = Path(task)
            if task_path.is_file():
                config = utils.load_yaml_config(str(task_path))
                task_names.append(config)

        # Check for missing tasks
        task_missing = [
            task for task in task_list if task not in task_names and "*" not in task
        ]

        if task_missing:
            missing = ", ".join(task_missing)
            raise ValueError(f"Tasks not found: {missing}")

        # Update tasks with resolved names
        self.tasks = task_names
        return task_manager

    def _set_trust_remote_code(self) -> None:
        """Apply trust_remote_code setting if enabled."""
        if self.trust_remote_code:
            # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
            # because it's already been determined based on the prior env var before launching our
            # script--`datasets` gets imported by lm_eval internally before these lines can update the env.
            import datasets

            datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

            # Add to model_args for the actual model initialization
            if self.model_args is None:
                self.model_args = {}
            self.model_args["trust_remote_code"] = True