# transformer-lens.py: evaluate a TransformerLens HookedTransformer with lm-eval-harness.
import types
import warnings

import torch
import torch.nn as nn
from transformer_lens import HookedTransformer
from transformers import AutoConfig

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM


def evaluate_lm_eval(lens_model: HookedTransformer, tasks: list[str], **kwargs):
    """Run lm-eval-harness tasks against a TransformerLens ``HookedTransformer``.

    The model is wrapped in a small ``nn.Module`` adapter that exposes the
    HuggingFace-style attributes (``tokenizer``, ``config``, ``device``,
    ``tie_weights``, ``.logits`` on the forward output) that lm-eval's ``HFLM``
    wrapper expects. The wrapped model is then passed to
    ``evaluator.simple_evaluate``.

    Args:
        lens_model: The TransformerLens model to evaluate.
        tasks: lm-eval task names, e.g. ``["arc_easy"]``.
        **kwargs: Forwarded unchanged to ``evaluator.simple_evaluate``.

    Returns:
        The object returned by ``evaluator.simple_evaluate``, unchanged.
    """

    class HFLikeModelAdapter(nn.Module):
        """Adapts HookedTransformer to match the HuggingFace interface expected by lm-eval"""

        def __init__(self, model: HookedTransformer):
            super().__init__()
            # Assigning an nn.Module registers it as a submodule, so nn.Module's
            # default train()/eval() propagate to it (no overrides needed).
            self.model = model
            self.tokenizer = model.tokenizer
            # NOTE(review): assumes cfg.tokenizer_name is a HuggingFace repo id
            # that also hosts a model config matching this model — confirm for
            # models whose tokenizer is borrowed from a different repo.
            self.config = AutoConfig.from_pretrained(model.cfg.tokenizer_name)
            self.device = model.cfg.device

        def tie_weights(self):
            """No-op: present only because the HF-style caller invokes it."""

        def forward(self, input_ids=None, attention_mask=None, **kwargs):
            """Run the wrapped model and expose a tensor result as ``.logits``.

            A tensor result is wrapped in a lightweight namespace instead of
            having ``.logits`` attached to the tensor itself. Storing the tensor
            in its own ``__dict__`` creates a reference cycle, which keeps the
            logits alive until the cyclic garbage collector runs. Non-tensor
            results are returned unchanged.
            """
            output = self.model(input_ids, attention_mask=attention_mask, **kwargs)
            if isinstance(output, torch.Tensor):
                return types.SimpleNamespace(logits=output)
            return output

        def to(self, *args, **kwargs):
            """Move the wrapped model and return this adapter (``nn.Module.to`` contract).

            Delegates to the wrapped model's own ``to`` so its device handling
            stays in charge. ``device`` is then re-read from ``cfg.device``;
            presumably HookedTransformer.to updates it, and if not, the value
            simply stays as before.
            """
            self.model.to(*args, **kwargs)
            self.device = self.model.cfg.device
            return self

    model = HFLikeModelAdapter(lens_model)
    # Scope the filter to this call rather than permanently mutating the
    # process-wide warnings filter list.
    # NOTE(review): this only hides the message if it is emitted through the
    # `warnings` module; a `logging`-based message is unaffected.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="Failed to get model SHA for")
        results = evaluator.simple_evaluate(
            model=HFLM(pretrained=model, tokenizer=model.tokenizer),
            tasks=tasks,
            verbosity="WARNING",
            **kwargs,
        )
    return results


if __name__ == "__main__":
    # Demo run: load the small Pythia checkpoint through TransformerLens, score
    # it on ARC-Easy, and print the "results" section of the returned report.
    pythia = HookedTransformer.from_pretrained("pythia-70m")
    eval_output = evaluate_lm_eval(pythia, tasks=["arc_easy"])
    print(eval_output["results"])