Commit 6b453bfd authored by uyhcire

Tweak StoryCloze script to be agnostic to tokenization

parent 4dbde45a
@@ -89,11 +89,6 @@ def evaluate_example(model, tokenizer, example):
 def compute_per_token_logit_for_completion(model, tokenizer, prompt, completion):
-    prompt_token_count = (
-        tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
-        .to("cuda")
-        .shape[1]
-    )
     encoded_prompt_with_completion = tokenizer.encode(
         prompt + " " + completion,
         add_special_tokens=False,
@@ -114,15 +109,7 @@ def compute_per_token_logit_for_completion(model, tokenizer, prompt, completion)
         input_tokens_at_positions_with_logits.unsqueeze(1),
     ).squeeze(1)
-    return (
-        logits_for_provided_tokens[
-            prompt_token_count
-            # Again, the model does not predict the first input token, so we need
-            - 1 :
-        ]
-        .mean()
-        .item()
-    )
+    return logits_for_provided_tokens.mean().item()
 if __name__ == "__main__":
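The net effect: the completion score is now the mean gathered logit over every provided token (minus the first, which the model never predicts), rather than a slice starting at the prompt's token count. Slicing at prompt_token_count implicitly assumes the prompt tokenizes the same way on its own as it does as a prefix of prompt + " " + completion, which not every tokenizer guarantees; averaging over all provided tokens drops that assumption. A minimal sketch of the resulting scoring scheme, assuming a Hugging Face causal LM on CPU; the gpt2 checkpoint, the mean_logit_for_completion name, and the example sentences are illustrative assumptions, not the repository's actual code:

# Hedged sketch only: scores a completion by the mean logit the model assigns
# to each provided token, mirroring the simplified return above. Model choice
# and function name are assumptions for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def mean_logit_for_completion(model, tokenizer, prompt, completion):
    # Encode prompt and completion together; nothing is sliced off by the
    # prompt's token count, so the score does not depend on whether the prompt
    # tokenizes identically on its own and as a prefix of the full sequence.
    input_ids = tokenizer.encode(
        prompt + " " + completion,
        add_special_tokens=False,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(input_ids).logits  # (1, seq_len, vocab_size)

    # The model does not predict the first input token, so align the logits at
    # position i with the input token at position i + 1.
    logits_at_positions = logits[:, :-1, :]  # (1, seq_len - 1, vocab_size)
    tokens_with_logits = input_ids[:, 1:]    # (1, seq_len - 1)
    logits_for_provided_tokens = torch.gather(
        logits_at_positions,
        2,
        tokens_with_logits.unsqueeze(2),
    ).squeeze(2)
    return logits_for_provided_tokens.mean().item()


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()
    prompt = "Ann had always wanted to try skydiving."
    for ending in ("She booked a jump for the weekend.", "She decided to stay home forever."):
        print(ending, mean_logit_for_completion(model, tokenizer, prompt, ending))

In a StoryCloze-style harness, each candidate ending would be scored this way and the higher-scoring ending chosen.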