Unverified commit 2e1b05d2 authored by Stella Biderman, committed by GitHub

Merge pull request #71 from EleutherAI/uyhcire-normalize-loglikelihoods

Fix eval script to normalize loglikelihoods
parents 870a247a 8315dce7
@@ -4,6 +4,7 @@ import time
 import click
 import torch
+import torch.nn.functional as F
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -53,18 +54,18 @@ def evaluate_examples(model_runner, examples):
         for prompt, example in zip(prompts, examples)
     ]
-    average_token_logits_with_sentence_1 = (
-        model_runner.compute_average_token_logits_on_batch(inputs_for_sentence_1)
-    )
-    average_token_logits_with_sentence_2 = (
-        model_runner.compute_average_token_logits_on_batch(inputs_for_sentence_2)
-    )
+    average_token_loglikelihoods_with_sentence_1 = (
+        model_runner.compute_average_token_loglikelihoods_on_batch(inputs_for_sentence_1)
+    )
+    average_token_loglikelihoods_with_sentence_2 = (
+        model_runner.compute_average_token_loglikelihoods_on_batch(inputs_for_sentence_2)
+    )
     evaluation_results = []
     for i in range(len(examples)):
         if (
-            average_token_logits_with_sentence_1[i]
-            > average_token_logits_with_sentence_2[i]
+            average_token_loglikelihoods_with_sentence_1[i]
+            > average_token_loglikelihoods_with_sentence_2[i]
         ):
             model_answer = examples[i]["RandomFifthSentenceQuiz1"]
             model_answer_code = "1"
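For context, this loop implements a two-way forced-choice evaluation: the model is scored on both candidate fifth sentences and the one with the higher average per-token log-likelihood is taken as its answer. A minimal illustrative sketch of that decision rule follows; the else-branch (the `RandomFifthSentenceQuiz2` key and code `"2"`) is inferred by symmetry and is not visible in this hunk:

```python
# Illustrative sketch only; mirrors the branch visible in the diff above.
def pick_answer(example, score_sentence_1, score_sentence_2):
    # Higher average per-token log-likelihood wins.
    if score_sentence_1 > score_sentence_2:
        return example["RandomFifthSentenceQuiz1"], "1"
    # Assumed mirror branch (not shown in this hunk).
    return example["RandomFifthSentenceQuiz2"], "2"
```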
@@ -96,15 +97,15 @@ class ModelRunner:
         model_runner.model = AutoModelForCausalLM.from_pretrained(
             # 117M
-            pretrained_model_name_or_path="gpt2",
+            pretrained_model_name_or_path="gpt2-large",
             config=AutoConfig.from_pretrained(
-                "gpt2",
+                "gpt2-large",
                 # <|endoftext|>
                 pad_token_id=50256,
             ),
         ).to("cuda")
         model_runner.model = model_runner.model.eval()
-        model_runner.tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        model_runner.tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
         model_runner.tokenizer.pad_token = "<|endoftext|>"
         prompt = "The quick brown fox jumps over"
@@ -126,11 +127,11 @@ class ModelRunner:
         return model_runner
-    def compute_average_token_logits_on_batch(self, input_texts):
+    def compute_average_token_loglikelihoods_on_batch(self, input_texts):
         """
-        For each input text in the batch, compute the average logit (log-likelihood) over all tokens.
-        For example, if an input sequence is 3 tokens long, and the token logits are [-1, -2, -3], the "average token logit" is -2.
+        For each input text in the batch, compute the average log-likelihood over all tokens.
+        For example, if an input sequence is 3 tokens long, and the token loglikelihoods are [-1, -2, -3], the "average token loglikelihood" is -2.
         """
         # The ModelRunner can take a big batch of input_texts, and it can be as large as the caller wants.
         # But to prevent the GPU from running out of memory, we need to subdivide the overall batch
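Why average rather than sum: the two candidate endings can tokenize to different lengths, and summed log-likelihoods systematically favor shorter sequences. A toy illustration (values invented):

```python
import torch

short_ending = torch.tensor([-2.0, -2.0])               # 2 tokens
long_ending = torch.tensor([-1.5, -1.5, -1.5, -1.5])    # 4 tokens

# Summing favors the shorter ending simply because it has fewer tokens...
assert short_ending.sum() > long_ending.sum()    # -4.0 > -6.0
# ...while the per-token average compares the endings fairly.
assert long_ending.mean() > short_ending.mean()  # -1.5 > -2.0
```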
@@ -138,16 +139,16 @@ class ModelRunner:
         # For GPT-2-117M, a GPU can process a batch of roughly 10 or so inputs before the inference latency starts to increase.
         gpu_batch_size = 20
-        average_token_logits = []
+        average_token_loglikelihoods = []
         for i in range(0, len(input_texts), gpu_batch_size):
-            average_token_logits.extend(
-                self._average_token_logits_on_gpu_batch(
+            average_token_loglikelihoods.extend(
+                self._average_token_loglikelihoods_on_gpu_batch(
                     input_texts[i : i + gpu_batch_size]
                 )
             )
-        return average_token_logits
+        return average_token_loglikelihoods
-    def _average_token_logits_on_gpu_batch(self, input_texts):
+    def _average_token_loglikelihoods_on_gpu_batch(self, input_texts):
         tokenized_inputs = self.tokenizer(
             input_texts,
             add_special_tokens=False,
@@ -164,42 +165,49 @@ class ModelRunner:
         output_logits = self.model(tokenized_inputs).logits
         self.num_inferences += 1
-        # Align the output logits to the input tokens.
-        logits_for_input_positions = output_logits[
+        # Normalize probabilities - at each position, the token likelihoods should add up to 1
+        output_loglikelihoods = F.log_softmax(
+            output_logits,
+            # The embedding dimension
+            dim=-1,
+        )
+        # Align the output loglikelihoods to the input tokens.
+        loglikelihoods_for_input_positions = output_loglikelihoods[
             # The batch dimension
             :,
             # The position dimension
-            # The last logit needs to be dropped, because it's predicting the "next token", and it doesn't correspond to any input token
+            # The last loglikelihood needs to be dropped, because it's predicting the "next token", and it doesn't correspond to any input token
             :-1,
             # The embedding dimension
             :,
         ]
-        input_tokens_at_positions_with_logits = tokenized_inputs[
+        input_tokens_at_positions_with_loglikelihoods = tokenized_inputs[
             # The batch dimension
             :,
             # The position dimension
            # The model does not predict the first input token, so the first token needs to be dropped.
             1:,
         ]
-        # At each position, the model outputs ~50k logits, one for every possible token.
-        # To get the logits of the tokens that were actually provided, we need to select the right logit at each position.
-        logits_for_provided_tokens = torch.gather(
-            logits_for_input_positions,
+        # At each position, the model outputs ~50k loglikelihoods, one for every possible token.
+        # To get the loglikelihoods of the tokens that were actually provided, we need to select the right loglikelihood at each position.
+        loglikelihoods_for_provided_tokens = torch.gather(
+            loglikelihoods_for_input_positions,
             2,
-            input_tokens_at_positions_with_logits.unsqueeze(2),
+            input_tokens_at_positions_with_loglikelihoods.unsqueeze(2),
         ).squeeze(2)
-        mask_for_non_padded_positions = input_tokens_at_positions_with_logits != 50256
-        average_token_logits = (
-            logits_for_provided_tokens * mask_for_non_padded_positions
+        mask_for_non_padded_positions = input_tokens_at_positions_with_loglikelihoods != 50256
+        average_token_loglikelihoods = (
+            loglikelihoods_for_provided_tokens * mask_for_non_padded_positions
         ).sum(1) / mask_for_non_padded_positions.sum(1)
-        average_token_logits = average_token_logits.tolist()
+        average_token_loglikelihoods = average_token_loglikelihoods.tolist()
         end_time = time.time()
         print(
             f"Time to evaluate once (inference #{self.num_inferences}): {end_time - start_time}"
         )
-        return average_token_logits
+        return average_token_loglikelihoods
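Taken together, the new scoring path is: normalize logits into log-probabilities, shift to align each prediction with the token it predicts, gather the log-probability of each actual token, then mask out padding and average. A self-contained sketch on dummy data; the toy 4-token vocabulary and values are invented, and only the shapes, the shifting, and the mask-against-pad-id logic mirror the diff:

```python
import torch
import torch.nn.functional as F

PAD = 3  # stand-in for GPT-2's pad id 50256 in this toy vocabulary
tokens = torch.tensor([[0, 2, 1, PAD]])  # (batch=1, positions=4), right-padded
logits = torch.randn(1, 4, 4)            # dummy model output: (batch, positions, vocab)

# 1. Normalize: log_softmax makes each position's token probabilities sum to 1.
loglikelihoods = F.log_softmax(logits, dim=-1)
assert torch.allclose(loglikelihoods.exp().sum(dim=-1), torch.ones(1, 4))

# 2. Align: the output at position t predicts token t+1, so drop the last
#    prediction and the first input token.
predictions = loglikelihoods[:, :-1, :]
targets = tokens[:, 1:]

# 3. Select the log-likelihood assigned to each token that was actually provided.
per_token = torch.gather(predictions, 2, targets.unsqueeze(2)).squeeze(2)

# 4. Mask out padding and average over real tokens only.
mask = targets != PAD
average = (per_token * mask).sum(1) / mask.sum(1)
print(average.tolist())  # one average log-likelihood per sequence in the batch
```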
 if __name__ == "__main__":
...