"vscode:/vscode.git/clone" did not exist on "c26507484fca9c6a901754b16af56285df29aa2b"
cot_decoding.py 3.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from math import exp
from pprint import pformat

import sglang as sgl

YELLOW = "\033[1;33m"
GREEN = "\033[1;32m"
BLUE = "\033[1;34m"
CLEAR = "\033[1;0m"


@sgl.function
def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
    """CoT Decoding: http://arxiv.org/abs/2402.10200

    Instead of sampling, explore the top-k candidate *first* tokens of the
    answer, greedily decode a full continuation from each, and score each path
    by the average probability margin between the top-1 and top-2 tokens at
    every decoding step (a large margin indicates a confident CoT path).

    Args:
        s: The sglang program state being built.
        question: The question to answer.
        get_top_k: Number of alternative first tokens (paths) to explore.
        is_chat_model: Wrap the prompt in chat roles when True.
        verbose: When True, also print per-token top-1/top-2 probabilities.
    """

    if is_chat_model:
        s += sgl.user("Question: " + question + "\nAnswer:")
        s += sgl.assistant_begin()
    else:
        s += "Question: " + question + "\nAnswer:"

    step_0 = s.fork(1)[0]
    forks = s.fork(get_top_k)
    answer_forks = s.fork(get_top_k)

    # Decoding step 0: peek at the top-k candidate first tokens without
    # emitting any of them (max_tokens=0 only collects logprobs).
    step_0 += sgl.gen(
        "get_top_k",
        max_tokens=0,
        return_logprob=True,
        top_logprobs_num=get_top_k,
        return_text_in_logprobs=True,
    )
    logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0]

    print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs))

    for idx, (f, token) in enumerate(zip(forks, logprobs)):
        logprob, token_id, text = token
        f += text

        # A path whose very first token ends the text has no chain of
        # thought to score; report it as nan and move on.
        if text == "<|end_of_text|>":
            print(
                f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score=nan, answer=nan){CLEAR}"
            )
            continue

        # Continue greedy decoding from the forced first token.
        f += sgl.gen(
            "answer",
            temperature=0,
            max_tokens=1024,
            return_logprob=True,
            top_logprobs_num=2,
            return_text_in_logprobs=True,
        )

        # Probability disparity between the top and secondary tokens at each
        # step. Hoist the meta-info lookup: query it once instead of thrice.
        top_logprobs = f.get_meta_info("answer")["decode_top_logprobs"]
        x1s = [exp(xt[0][0]) for xt in top_logprobs]
        x2s = [exp(xt[1][0]) for xt in top_logprobs]
        tokens = [xt[0][2] for xt in top_logprobs]

        # Average top-1/top-2 margin; nan (not ZeroDivisionError) if the
        # continuation produced no tokens, mirroring the empty-path case above.
        delta = (sum(x1s) - sum(x2s)) / len(x1s) if x1s else float("nan")

        # extract the answer span (without the '<|end_of_text|>' token)
        answer_forks[idx] += text + f["answer"] + "\nSo the answer is"
        answer_forks[idx] += sgl.gen(
            "answer_span",
            temperature=0,
            max_tokens=64,
            return_logprob=True,
            top_logprobs_num=2,
            return_text_in_logprobs=True,
        )
        answer = answer_forks[idx]["answer_span"].replace("\n", " ").strip(":")
        print(
            f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score={delta}, answer={answer}){CLEAR}"
        )

        # Strip the "ProgramState(...)" wrapper from the state's repr to show
        # only the generated text.
        generated_text = str(answer_forks[idx])[len("ProgramState(") : -1]
        print(f"{BLUE}{pformat(generated_text)}{CLEAR}")

        if verbose:
            # Same hoisting for the answer span's meta info: one lookup.
            answer_top_logprobs = answer_forks[idx].get_meta_info("answer_span")[
                "decode_top_logprobs"
            ]
            answer_tokens = [xt[0][2] for xt in answer_top_logprobs]
            answer_x1s = [exp(xt[0][0]) for xt in answer_top_logprobs]
            answer_x2s = [exp(xt[1][0]) for xt in answer_top_logprobs]

            for token, x1, x2 in zip(tokens, x1s, x2s):
                print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})", end="")
            print("\n===========")
            for token, x1, x2 in zip(answer_tokens, answer_x1s, answer_x2s):
                print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})", end="")
            print()


# Connect to a locally running sglang server, then run CoT decoding on a
# sample grade-school math question, exploring the top-10 first tokens.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

state = cot_decoding.run(
    question=r"Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4  weeks?",
    get_top_k=10,
    is_chat_model=True,
    verbose=False,
)