Unverified Commit 065bb947 authored by Jeffrey Fong's avatar Jeffrey Fong Committed by GitHub
Browse files

Fix RuntimeEndpoint.select method (#1495)

parent f42e9bfb
...@@ -244,7 +244,8 @@ class RuntimeEndpoint(BaseBackend): ...@@ -244,7 +244,8 @@ class RuntimeEndpoint(BaseBackend):
"temperature": 0, "temperature": 0,
}, },
"return_logprob": True, "return_logprob": True,
"logprob_start_len": max(prompt_len - 2, 0), # for token healing "return_text_in_logprobs": True,
"logprob_start_len": prompt_len - 2, # For token healing
} }
obj = self._generate_http_request(s, data) obj = self._generate_http_request(s, data)
...@@ -254,6 +255,17 @@ class RuntimeEndpoint(BaseBackend): ...@@ -254,6 +255,17 @@ class RuntimeEndpoint(BaseBackend):
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj] input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj] output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
# Remove extra token if no token healing occurred
for i in range(len(input_token_logprobs)):
healed_token_str = input_token_logprobs[i][0][-1]
healed_token_logprob = input_token_logprobs[i][0][0]
if s.text_.endswith(healed_token_str):
normalized_prompt_logprobs[i] = (
normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
- healed_token_logprob
) / (len(input_token_logprobs[i]) - 1)
input_token_logprobs[i] = input_token_logprobs[i][1:]
# Compute unconditional logprobs if required # Compute unconditional logprobs if required
if choices_method.requires_unconditional_logprobs: if choices_method.requires_unconditional_logprobs:
input_ids = [[el[1] for el in subl] for subl in input_token_logprobs] input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment