"stubs/torch/utils/checkpoint.pyi" did not exist on "8634280c831c0995220de8a35cb292fd2c67095a"
kl_divergence.py 5.43 KB
Newer Older
Casper's avatar
Casper committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# adapted from https://gist.github.com/Ttl/0d51f739dc59254b4b2183e259c97d82

import torch
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import (
    PreTrainedTokenizer,
    PreTrainedModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)

try:
    from scipy.stats import bayes_mvs
    from scipy.stats import t as student_t
    from scipy.stats.mstats import mquantiles_cimj

    SCIPY_INSTALLED = True
except ImportError:
    SCIPY_INSTALLED = False


@torch.jit.script
def rel_entr(x, y):
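    # Elementwise relative entropy x * log(x / y), following the scipy.special.rel_entr
    # convention: 0 where x == 0, +inf where x > 0 and y <= 0.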
    mask = (x > 0) & (y > 0)
    result = torch.where(mask, x * torch.log(x / y), torch.zeros_like(x))
    result[(x > 0) & (y <= 0)] = float("inf")
    return result


def bin_conf(p, n, z):
    # Binomial distribution confidence bounds
    # Bayes estimator when p is degenerate
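    # Half-width of the normal-approximation (Wald) interval, z * sqrt(p * (1 - p) / n);
    # nudging a degenerate p to 1 / (n + 2) keeps the bound from collapsing to zero.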
    if p == 0:
        p = 1 / (n + 2)
    if p == 1:
        p = 1 - 1 / (n + 2)
    return z * np.sqrt(p * (1 - p) / n)


def eval_kl_divergence(
    ref_model: PreTrainedModel,
    eval_model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    seqlen: int,
):
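    """Compare eval_model against ref_model on the WikiText-2 test set.

    Reports the mean per-token KL divergence KL(ref || eval) with a
    confidence interval, high quantiles of the per-token KL, and top-n
    token agreement rates with binomial confidence bounds.
    """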
    if not SCIPY_INSTALLED:
        raise Exception(
            "SciPy needs to be installed for KL Divergence evaluation: pip install scipy"
        )

    # load dataset
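    # the whole WikiText-2 test split is joined with "\n\n", tokenized once,
    # then sliced into non-overlapping windows of `seqlen` tokens below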
    data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    data = tokenizer("\n\n".join(data["text"]), return_tensors="pt")
    data = data.input_ids.to(ref_model.device)

    n_samples = data.numel() // seqlen

    alpha = 0.01  # 1 - alpha is the confidence level used for the reported intervals
    kls = []
    top1 = 0
    top5 = 0
    top10 = 0
    eval_top5 = 0
    eval_top10 = 0
    samples = 0
    i = 0

    # start eval
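    # each window is run through both models; per-token KL and top-n agreement
    # counts are accumulated across all windows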
    with tqdm(range(n_samples), desc="KL Div") as progress_bar:
        for i in progress_bar:
            start_index = i * seqlen
            end_index = (i + 1) * seqlen
            batch_len = end_index - start_index
            batch = data[:, start_index:end_index]

            # get logits
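            # forward pass without gradients; output index [0] is the logits
            # tensor of shape (1, seqlen, vocab_size)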
            with torch.no_grad():
                y1 = ref_model(batch)[0]
                y2 = eval_model(batch)[0]

            # kl divergence
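            # per-token KL(ref || eval): elementwise relative entropy summed over
            # the vocab dim; nan_to_num keeps inf/NaN entries from poisoning the stats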
            y1_probs = torch.softmax(y1, dim=-1)
            y2_probs = torch.softmax(y2, dim=-1)
            relative_entropy = rel_entr(y1_probs, y2_probs)
            kl_div = torch.sum(relative_entropy, dim=-1).squeeze(0)
            kls.append(torch.nan_to_num(kl_div).tolist())

            # stats: how often the reference model's top-1 token appears in the
            # eval model's top-n predictions (and vice versa)
            eval_argmax = torch.argmax(y2, dim=-1).squeeze(0)
            ref_argmax = torch.argmax(y1, dim=-1).squeeze(0)
            eval_part5 = torch.topk(y2, k=5, dim=-1).indices.squeeze(0)
            ref_part5 = torch.topk(y1, k=5, dim=-1).indices.squeeze(0)
            eval_part10 = torch.topk(y2, k=10, dim=-1).indices.squeeze(0)
            ref_part10 = torch.topk(y1, k=10, dim=-1).indices.squeeze(0)
            top1 += (eval_argmax == ref_argmax).sum().item()
            top5 += (ref_argmax.unsqueeze(-1) == eval_part5).any(dim=-1).sum().item()
            top10 += (ref_argmax.unsqueeze(-1) == eval_part10).any(dim=-1).sum().item()
            eval_top5 += (eval_argmax.unsqueeze(-1) == ref_part5).any(dim=-1).sum().item()
            eval_top10 += (eval_argmax.unsqueeze(-1) == ref_part10).any(dim=-1).sum().item()
            samples += batch_len

            progress_bar.set_description(
                f"KL Div: {torch.mean(torch.Tensor(kls)):.4g}, "
                f"Top 1: {top1 / samples:.4g}, "
                f"Top 5: {top5 / samples:.4g}, "
                f"Top 10: {top10 / samples:.4g}"
            )

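    # summary statistics over all recorded per-token KL values: a Student-t
    # critical value (reused for the binomial bounds below) and a Bayesian
    # (1 - alpha) interval for the mean KL from bayes_mvs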
    z = student_t.ppf(1 - alpha / 2, samples)
    m_conf = z * np.sqrt(np.mean(np.square(kls)) / len(kls))
    m, _, __ = bayes_mvs(kls, 1 - alpha)
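
    # empirical high quantiles of the per-token KL, with Maritz-Jarrett
    # confidence limits from mquantiles_cimj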
    q90 = np.quantile(kls, 0.90)
    q95 = np.quantile(kls, 0.95)
    q99 = np.quantile(kls, 0.99)
    q_bounds = mquantiles_cimj(kls, prob=[0.90, 0.95, 0.99])

    print(" -- ")
    print(" ** Reference model:", ref_model.config.model_type)
    print(" ** Evaluation model:", eval_model.config.model_type)
    print(" -- ")
    print(f" ** KL Divergence: {m[0]:.6g}, [{m[1][0]:.6g} - {m[1][1]:.6g}]")
    print(f" ** q90: {q90:.4g}, [{q_bounds[0][0]:.4g} - {q_bounds[1][0]:.4g}]")
    print(f" ** q95: {q95:.4g}, [{q_bounds[0][1]:.4g} - {q_bounds[1][1]:.4g}]")
    print(f" ** q99: {q99:.4g}, [{q_bounds[0][2]:.4g} - {q_bounds[1][2]:.4g}]")
    print(f"max: {np.max(kls):.4g}")
    print(" -- ")
    print("Reference top token in eval top-n probability:")
    print(
        f" ** ref_top1: {top1 / samples:.4g} ± {bin_conf(top1/samples, samples, z):.4g}"
    )
    print(
        f" ** ref_top5: {top5 / samples:.4g} ± {bin_conf(top5/samples, samples, z):.4g}"
    )
    print(
        f" ** ref_top10: {top10 / samples:4g} ± {bin_conf(top10/samples, samples, z):.4g}"
    )
    print("Eval top token in reference top-n probability:")
    print(
        f" ** eval_top5: {eval_top5 / samples:.4g} ± {bin_conf(eval_top5/samples, samples, z):.4g}"
    )
    print(
        f" ** eval_top10: {eval_top10 / samples:4g} ± {bin_conf(eval_top10/samples, samples, z):.4g}"
    )


if __name__ == "__main__":
    # ref_model_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
    # eval_model_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T"
    ref_model_path = eval_model_path = "gpt2"
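    # comparing a model to itself is a sanity check (KL ~0, agreement ~1);
    # point eval_model_path at a different checkpoint to compare two models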

    tokenizer = AutoTokenizer.from_pretrained(ref_model_path)
    ref_model = AutoModelForCausalLM.from_pretrained(ref_model_path, device_map="auto")
    eval_model = AutoModelForCausalLM.from_pretrained(
        eval_model_path, device_map="auto"
    )

    eval_kl_divergence(ref_model, eval_model, tokenizer, seqlen=1024)