Unverified Commit 7c011370 authored by Leo Gao, committed by GitHub

Get rid of get_token_logprobs

parent ae63eeb7
@@ -107,27 +107,6 @@ class GPT3LM(BaseLM):
     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens)

-    def get_token_logprobs(self, input_tokens, pred_tokens):
-        pred_start = len(input_tokens) - len(pred_tokens) + 1
-        # We're going to stitch together the input_tokens and pred_tokens
-        # In the longest case, this gets us to length = max_seq_len+1 (which the API works with)
-        assert input_tokens[pred_start:] == pred_tokens[:-1]
-        token_ids = input_tokens + [pred_tokens[-1]]
-        response = oa_completion(
-            engine=self.engine,
-            prompt=token_ids,
-            max_tokens=0,
-            temperature=0.0,
-            logprobs=0,
-            echo=True,
-        )
-        logprobs = np.array(response["choices"][0]["logprobs"]["token_logprobs"][pred_start:])
-        positions = np.arange(pred_start-1, pred_start-1 + len(token_ids[pred_start:]))
-        return {
-            "logprobs": logprobs,
-            "positions": positions,
-        }

     def _loglikelihood_tokens(self, requests, disable_tqdm=False):
         res = []
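Note on the removed helper: it scored a continuation by stitching `pred_tokens` onto `input_tokens` and reading the echoed logprobs back from the completion API. A minimal sketch of that index arithmetic with toy token ids (the values and variable bindings below are illustrative only, not part of this diff):

```python
# Toy walk-through of the stitching done by the removed get_token_logprobs.
# pred_tokens overlaps the tail of input_tokens except for its final token,
# which is appended so the API (echo=True, max_tokens=0) can score every
# predicted token while staying within max_seq_len + 1 tokens.
input_tokens = [11, 12, 13, 14]   # full context, already tokenized
pred_tokens = [13, 14, 15]        # continuation to score

pred_start = len(input_tokens) - len(pred_tokens) + 1   # -> 2
assert input_tokens[pred_start:] == pred_tokens[:-1]    # [13, 14] == [13, 14]

token_ids = input_tokens + [pred_tokens[-1]]            # [11, 12, 13, 14, 15]
# The echoed response carries one logprob per token of token_ids; slicing the
# returned token_logprobs from pred_start keeps exactly the logprobs of
# pred_tokens (here: tokens 13, 14, 15).
```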