main.py
import csv
import os

import click
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


@click.command()
@click.argument("datadir", required=True)
def main(datadir):
    model = AutoModelForCausalLM.from_pretrained(
        # 117M
        pretrained_model_name_or_path="gpt2",
        config=AutoConfig.from_pretrained(
            "gpt2",
            # <|endoftext|>
            pad_token_id=50256,
        ),
    ).to("cuda")
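    # Inference only: eval mode disables dropout.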
    model = model.eval()

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    prompt = "The quick brown fox jumps over"
    encoded_prompt = tokenizer.encode(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).to("cuda")

    # Sanity check the model
    [output_token_ids] = model.generate(
        input_ids=encoded_prompt,
        max_length=100,
        do_sample=False,
        num_return_sequences=1,
    )
    decoded_output = tokenizer.decode(output_token_ids.tolist())
    # Next word should be "the" ("The quick brown fox jumps over *the*...")
    print(decoded_output[len(prompt + " ") :][:10])
    assert decoded_output[len(prompt + " ") :].startswith("the")

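    # StoryCloze (Spring 2016) test set: each row has four context sentences,
    # two candidate fifth sentences, and the index of the correct ending.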
    with open(
        os.path.join(datadir, "cloze_test_test__spring2016 - cloze_test_ALL_test.csv")
    ) as f:
        storycloze_test_examples = list(csv.DictReader(f))

    example_evaluations = [
        evaluate_example(model, tokenizer, example)
        for example in storycloze_test_examples
    ]
    fraction_correct = sum(
        evaluation["was_model_correct"] for evaluation in example_evaluations
    ) / len(example_evaluations)
    print(f"Fraction correct: {fraction_correct}")


def evaluate_example(model, tokenizer, example):
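    # Score both candidate endings and return the model's pick, plus whether it
    # matched the labeled correct ending.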
    storycloze_prompt = "{} {} {} {}".format(
        example["InputSentence1"],
        example["InputSentence2"],
        example["InputSentence3"],
        example["InputSentence4"],
    )

    # Compare mean *per-token* scores (rather than summed scores) so longer
    # endings aren't penalized, as the paper did
    per_token_logit_for_sentence1 = compute_per_token_logit_for_completion(
        model, tokenizer, storycloze_prompt, example["RandomFifthSentenceQuiz1"]
    )
    per_token_logit_for_sentence2 = compute_per_token_logit_for_completion(
        model, tokenizer, storycloze_prompt, example["RandomFifthSentenceQuiz2"]
    )

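    # Pick the ending with the higher average per-token logit.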
    if per_token_logit_for_sentence1 > per_token_logit_for_sentence2:
        model_answer = example["RandomFifthSentenceQuiz1"]
        model_answer_code = "1"
    else:
        model_answer = example["RandomFifthSentenceQuiz2"]
        model_answer_code = "2"

    return {
        "model_answer": model_answer,
        "was_model_correct": model_answer_code == example["AnswerRightEnding"],
    }


def compute_per_token_logit_for_completion(model, tokenizer, prompt, completion):
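    # Mean logit the model assigns to the completion's tokens, conditioned on
    # the prompt.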
    prompt_token_count = (
        tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
        .to("cuda")
        .shape[1]
    )
    encoded_prompt_with_completion = tokenizer.encode(
        prompt + " " + completion,
        add_special_tokens=False,
        return_tensors="pt",
    ).to("cuda")
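    # One forward pass yields next-token logits at every position; no gradients
    # are needed for scoring.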
    with torch.no_grad():
        output_logits = model(encoded_prompt_with_completion).logits

    # Align the output logits to the input tokens.
    # The last logit predicts the token after the end of the input, so it doesn't correspond to any input token and is dropped.
    logits_for_input_positions = output_logits[0, :-1, :]
    # The model does not predict the first input token, so it needs to be dropped as well.
    input_tokens_at_positions_with_logits = encoded_prompt_with_completion[0, 1:]
    # At each position, the model outputs ~50k logits, one for every possible token.
    # To get the logits of the tokens that were actually provided, we need to select the right logit at each position.
    logits_for_provided_tokens = torch.gather(
        logits_for_input_positions,
        1,
        input_tokens_at_positions_with_logits.unsqueeze(1),
    ).squeeze(1)

    # Keep only the completion's tokens; subtract one because, as above, the
    # model does not predict the first input token.
    return logits_for_provided_tokens[prompt_token_count - 1 :].mean().item()


if __name__ == "__main__":
    main()