cli-logprob.py 508 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Send one prompt to a locally running /generate endpoint with log-probability
# reporting switched on, then print how many input-token and output-token
# log-prob entries the server returned.

import json

import requests

prompt = "The capital of taiwan is "

# (connect, read) timeouts in seconds. Without a timeout, requests waits
# indefinitely on a server that accepts the connection but never replies.
# The read timeout is deliberately generous because generation can be slow.
REQUEST_TIMEOUT = (5, 300)

response = requests.post(
    "http://0.0.0.0:8000/generate",
    json={
        "text": prompt,
        "sampling_params": {"temperature": 0},
        # Ask for log-probs of both the prompt tokens and the generated tokens.
        "return_logprob": True,
        "return_input_logprob": True,
        # NOTE(review): presumably "start reporting from the first prompt
        # token" -- confirm against the server's API documentation.
        "logprob_start_len": 0,
    },
    timeout=REQUEST_TIMEOUT,
)

# Fail with a clear HTTPError on 4xx/5xx instead of a confusing JSON decode
# error or KeyError further down when the body is not the expected payload.
response.raise_for_status()

j = response.json()
input_logprobs = j["meta_info"]["input_token_logprobs"]
output_logprobs = j["meta_info"]["output_token_logprobs"]

# Output: "<number of input log-prob entries> <number of output entries>".
print(len(input_logprobs), len(output_logprobs))