test_registry.py 14.9 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
#!/usr/bin/env python3
"""Comprehensive tests for the registry system."""

import threading

import pytest

from lm_eval.api.model import LM
from lm_eval.api.registry import (
    MetricSpec,
    Registry,
    get_metric,
    metric_agg_registry,
    metric_registry,
    model_registry,
    register_metric,
)


# Import metrics module to ensure decorators are executed
# import lm_eval.api.metrics


class TestBasicRegistry:
    """Exercise the core behaviour of a standalone ``Registry`` instance."""

    def test_create_registry(self):
        """A freshly constructed registry reports no entries."""
        registry = Registry("test")
        assert len(registry) == 0
        assert list(registry) == []

    def test_decorator_registration(self):
        """``register(alias)`` used as a class decorator stores the class."""
        registry = Registry("test")

        @registry.register("my_class")
        class MyClass:
            pass

        # The alias must be reachable via ``in``, ``get`` and subscripting.
        assert "my_class" in registry
        assert registry.get("my_class") == MyClass
        assert registry["my_class"] == MyClass

    def test_decorator_multiple_aliases(self):
        """Every alias passed to the decorator resolves to the same class."""
        registry = Registry("test")

        @registry.register("alias1", "alias2", "alias3")
        class MyClass:
            pass

        for alias in ("alias1", "alias2", "alias3"):
            assert registry.get(alias) == MyClass

    def test_decorator_auto_name(self):
        """With no alias given, the class name is used as the key."""
        registry = Registry("test")

        @registry.register()
        class AutoNamedClass:
            pass

        assert registry.get("AutoNamedClass") == AutoNamedClass

    def test_lazy_registration(self):
        """A ``module:attr`` string is stored as-is and resolved by ``get``."""
        import os

        registry = Registry("test")
        registry.register("join", lazy="os.path:join")

        # Until it is looked up, the entry is still the raw path string.
        assert isinstance(registry._objs["join"], str)

        # Looking it up yields the real callable.
        resolved = registry.get("join")
        assert resolved == os.path.join
        assert callable(resolved)

    def test_direct_registration(self):
        """An arbitrary object passed through ``lazy=`` is returned by ``get``."""
        registry = Registry("test")

        class DirectClass:
            pass

        instance = DirectClass()
        registry.register("direct", lazy=instance)

        assert registry.get("direct") == instance

    def test_metadata_removed(self):
        """Registration works when no metadata argument is supplied."""
        registry = Registry("test")

        @registry.register("test_class")
        class TestClass:
            pass

        assert "test_class" in registry
        assert registry.get("test_class") == TestClass

    def test_unknown_key_error(self):
        """Looking up a missing key raises a descriptive ``KeyError``."""
        registry = Registry("test")

        with pytest.raises(KeyError) as exc_info:
            registry.get("unknown")

        message = str(exc_info.value)
        assert "Unknown test 'unknown'" in message
        assert "Available:" in message

    def test_iteration(self):
        """Iteration yields keys in insertion order; ``items`` keeps lazies raw."""
        registry = Registry("test")

        lazy_targets = (("a", "os:getcwd"), ("b", "os:getenv"), ("c", "os:getpid"))
        for alias, target in lazy_targets:
            registry.register(alias, lazy=target)

        assert list(registry) == ["a", "b", "c"]
        assert len(registry) == 3

        # items() exposes (key, value) pairs without resolving lazy strings.
        entries = list(registry.items())
        assert len(entries) == 3
        assert entries[0][0] == "a"
        assert isinstance(entries[0][1], str)

    def test_mapping_protocol(self):
        """Subscripting and membership tests behave like a mapping."""
        registry = Registry("test")
        registry.register("test", lazy="os:getcwd")

        # __getitem__ agrees with get()
        assert registry["test"] == registry.get("test")

        # __contains__ for present and absent keys
        assert "test" in registry
        assert "missing" not in registry

        # __iter__ and __len__ are covered by test_iteration


class TestTypeConstraints:
    """Test type checking and base class constraints."""

    def test_base_class_constraint(self):
        """Only subclasses of ``base_cls`` may be registered via the decorator.

        A conforming subclass must end up in the registry, while a class that
        does not derive from the base must be rejected with ``TypeError``.
        """

        class BaseClass:
            pass

        reg = Registry("typed", base_cls=BaseClass)

        # Accepted: derives from the required base class.
        @reg.register("good")
        class GoodInline(BaseClass):
            pass

        # Verify the accepted class was actually stored, not merely that the
        # decorator returned without raising.
        assert "good" in reg
        assert reg.get("good") == GoodInline

        # Rejected: unrelated class.
        with pytest.raises(TypeError) as exc_info:

            @reg.register("bad")
            class BadInline:
                pass

        assert "must inherit from" in str(exc_info.value)

    def test_lazy_type_check(self):
        """A lazy entry of the wrong type is rejected when it is looked up."""

        class BaseClass:
            pass

        reg = Registry("typed", base_cls=BaseClass)

        # Registering the path string itself succeeds; ``os.path.join`` is a
        # function, not a BaseClass subclass.
        reg.register("bad_lazy", lazy="os.path:join")

        # The failure surfaces on access. Only the exception type is checked
        # because the error message varies.
        with pytest.raises(TypeError):
            reg.get("bad_lazy")


class TestCollisionHandling:
    """Check what happens when the same alias is registered more than once."""

    def test_identical_registration(self):
        """Registering the very same object twice under one alias is accepted."""
        registry = Registry("test")

        class MyClass:
            pass

        # Two registrations of the same object under the same alias.
        for _ in range(2):
            registry.register("test", lazy=MyClass)

        assert registry.get("test") == MyClass

    def test_different_registration_fails(self):
        """A second, different object under a taken alias raises ``ValueError``."""
        registry = Registry("test")

        class Class1:
            pass

        class Class2:
            pass

        registry.register("test", lazy=Class1)

        with pytest.raises(ValueError) as exc_info:
            registry.register("test", lazy=Class2)

        message = str(exc_info.value)
        assert "already registered" in message

    def test_lazy_to_concrete_upgrade(self):
        """A lazy string entry may be replaced by a concrete class."""
        registry = Registry("test")

        # Start with a lazy path string under the alias ...
        registry.register("myclass", lazy="test_registry:MyUpgradeClass")

        # ... then register a concrete class under the same alias.
        @registry.register("myclass")
        class MyUpgradeClass:
            pass

        assert registry.get("myclass") == MyUpgradeClass


class TestThreadSafety:
    """Test thread safety of registry operations."""

    @staticmethod
    def _run_threads(target, arg_sets):
        """Run ``target(*args)`` in one thread per entry of *arg_sets*.

        All threads are started before any is joined, and the call returns
        only after every thread has finished.
        """
        workers = [threading.Thread(target=target, args=args) for args in arg_sets]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

    def test_concurrent_access(self):
        """Ten threads resolving the same lazy entry all get the same object."""
        reg = Registry("test")

        # Register lazy entry
        reg.register("concurrent", lazy="os.path:join")

        results = []
        errors = []

        def access_item():
            # Exceptions are collected rather than raised: an exception inside
            # a worker thread would not fail the test on its own.
            try:
                results.append(reg.get("concurrent"))
            except Exception as e:
                errors.append(str(e))

        self._run_threads(access_item, [()] * 10)

        # Include the collected messages so a failure explains itself.
        assert not errors, f"worker threads raised: {errors}"
        assert len(results) == 10
        # All should get the same object
        assert all(r is results[0] for r in results)

    def test_concurrent_registration(self):
        """Ten threads registering distinct aliases all succeed."""
        reg = Registry("test")

        errors = []

        def register_item(name, value):
            try:
                reg.register(name, lazy=value)
            except Exception as e:
                errors.append(str(e))

        # Each thread registers its own alias with its own lazy path string.
        self._run_threads(
            register_item,
            [(f"item_{i}", f"module{i}:Class{i}") for i in range(10)],
        )

        assert not errors, f"worker threads raised: {errors}"
        assert len(reg) == 10


class TestMetricRegistry:
    """Test metric-specific registry functionality.

    NOTE(review): apart from ``test_metric_spec``, these tests register
    entries in the shared module-level ``metric_registry`` and
    ``metric_agg_registry`` rather than in a fresh ``Registry``.
    """

    def test_metric_spec(self):
        """``MetricSpec`` exposes the values it was constructed with."""

        def compute_fn(items):
            return [1 for _ in items]

        def agg_fn(values):
            return sum(values) / len(values)

        spec = MetricSpec(
            compute=compute_fn,
            aggregate=agg_fn,
            higher_is_better=True,
            output_type="probability",
        )

        assert spec.compute == compute_fn
        assert spec.aggregate == agg_fn
        assert spec.higher_is_better
        assert spec.output_type == "probability"

    def test_register_metric_decorator(self):
        """``@register_metric`` stores a usable ``MetricSpec`` under the name."""

        # Register aggregation function first
        @metric_agg_registry.register("test_mean")
        def test_mean(values):
            return sum(values) / len(values) if values else 0.0

        # Register metric
        @register_metric(
            metric="test_accuracy",
            aggregation="test_mean",
            higher_is_better=True,
            output_type="accuracy",
        )
        def compute_accuracy(items):
            return [1 if item["pred"] == item["gold"] else 0 for item in items]

        # Check registration
        assert "test_accuracy" in metric_registry
        spec = metric_registry.get("test_accuracy")
        assert isinstance(spec, MetricSpec)
        assert spec.higher_is_better
        assert spec.output_type == "accuracy"

        # Test compute function
        items = [
            {"pred": "a", "gold": "a"},
            {"pred": "b", "gold": "b"},
            {"pred": "c", "gold": "d"},
        ]
        result = spec.compute(items)
        assert result == [1, 1, 0]

        # Test aggregation. Compare with a tolerance instead of exact float
        # equality so the check does not depend on rounding details.
        agg_result = spec.aggregate(result)
        assert agg_result == pytest.approx(2 / 3)

    def test_metric_without_aggregation(self):
        """A metric registered without an aggregation cannot aggregate."""

        @register_metric(metric="no_agg", higher_is_better=False)
        def compute_something(items):
            return [len(item) for item in items]

        spec = metric_registry.get("no_agg")

        # Should raise NotImplementedError when aggregate is called
        with pytest.raises(NotImplementedError) as exc_info:
            spec.aggregate([1, 2, 3])

        assert "No aggregation function specified" in str(exc_info.value)

    def test_get_metric_helper(self):
        """``get_metric`` returns the compute callable of a registered metric."""

        # NOTE(review): this assumes an aggregation named "mean" is already
        # present in metric_agg_registry. Nothing in this file registers it;
        # confirm that importing lm_eval.api.registry / lm_eval.api.model is
        # enough to make it available.
        @register_metric(
            metric="helper_test",
            aggregation="mean",
        )
        def compute_helper(items):
            return items

        # get_metric returns just the compute function
        compute_fn = get_metric("helper_test")
        assert callable(compute_fn)
        assert compute_fn([1, 2, 3]) == [1, 2, 3]


class TestRegistryUtilities:
    """Cover the helper methods ``freeze``, ``_clear`` and ``origin``."""

    def test_freeze(self):
        """After ``freeze`` the backing mapping rejects writes but still reads."""
        registry = Registry("test")

        # Populate, then freeze.
        for alias, target in (("item1", "os:getcwd"), ("item2", "os:getenv")):
            registry.register(alias, lazy=target)
        registry.freeze()

        # Writing directly into the underlying ``_objs`` mapping must fail.
        with pytest.raises(TypeError):
            registry._objs["new"] = "value"

        # Existing entries remain readable and resolvable.
        assert "item1" in registry
        assert callable(registry.get("item1"))

    def test_clear(self):
        """``_clear`` removes every entry."""
        registry = Registry("test")

        for alias, target in (("item1", "os:getcwd"), ("item2", "os:getenv")):
            registry.register(alias, lazy=target)
        assert len(registry) == 2

        registry._clear()

        assert len(registry) == 0
        assert list(registry) == []

    def test_origin(self):
        """``origin`` is ``None`` for lazy entries and a location for classes."""
        registry = Registry("test")

        # A lazy string entry reports no origin.
        registry.register("lazy", lazy="os:getcwd")
        assert registry.origin("lazy") is None

        # A class registered through the decorator reports where it came from.
        @registry.register("concrete")
        class ConcreteClass:
            pass

        location = registry.origin("concrete")
        assert location is not None
        assert "test_registry.py" in location
        assert ":" in location  # separator before the line number


class TestBackwardCompatibility:
    """Check that the legacy names exported by the registry module still work."""

    def test_model_registry_alias(self):
        """``MODEL_REGISTRY`` is the same object as ``model_registry``."""
        from lm_eval.api.registry import MODEL_REGISTRY

        assert MODEL_REGISTRY is model_registry

        # Remember the size so the effect of one registration can be measured.
        size_before = len(MODEL_REGISTRY)

        @model_registry.register("test_model_compat")
        class TestModelCompat(LM):
            pass

        # The new entry must be visible through the legacy name right away.
        assert len(MODEL_REGISTRY) == size_before + 1
        assert "test_model_compat" in MODEL_REGISTRY

    def test_legacy_functions(self):
        """``register_model`` / ``get_model`` and the registry aliases work."""
        from lm_eval.api.registry import (
            AGGREGATION_REGISTRY,
            DEFAULT_METRIC_REGISTRY,
            get_model,
            register_model,
        )

        # Register through the legacy decorator ...
        @register_model("legacy_model")
        class LegacyModel(LM):
            pass

        # ... and read it back through the legacy accessor.
        assert get_model("legacy_model") == LegacyModel

        # The remaining legacy names point at the current registries.
        assert DEFAULT_METRIC_REGISTRY is metric_registry
        assert AGGREGATION_REGISTRY is metric_agg_registry


class TestEdgeCases:
    """Cover malformed input and failure paths of ``Registry``."""

    def test_invalid_lazy_format(self):
        """A lazy string without a colon fails with ``ValueError`` on lookup."""
        registry = Registry("test")
        registry.register("bad", lazy="no_colon_here")

        # The error is raised by the lookup, not by the registration above.
        with pytest.raises(ValueError, match="expected 'module:object'"):
            registry.get("bad")

    def test_lazy_module_not_found(self):
        """A lazy path naming a missing module raises ``ModuleNotFoundError``."""
        registry = Registry("test")
        registry.register("missing", lazy="nonexistent_module:Class")

        with pytest.raises(ModuleNotFoundError):
            registry.get("missing")

    def test_lazy_attribute_not_found(self):
        """A lazy path naming a missing attribute raises ``AttributeError``."""
        registry = Registry("test")
        registry.register("missing_attr", lazy="os:nonexistent_function")

        with pytest.raises(AttributeError):
            registry.get("missing_attr")

    def test_multiple_aliases_with_lazy(self):
        """Passing several aliases together with ``lazy=`` is rejected."""
        registry = Registry("test")

        with pytest.raises(ValueError, match="Exactly one alias required"):
            registry.register("alias1", "alias2", lazy="os:getcwd")


if __name__ == "__main__":
    # Propagate pytest's return code so that running this file as a script
    # exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))