"example/vscode:/vscode.git/clone" did not exist on "457c024d608c9b855775cc014a630c7e0d30710c"
Commit 96ea7ddc authored by Tian Yun

Added stopping criteria for generation

parent c27e29e1
@@ -116,10 +116,41 @@ class HFLM(BaseLM):
         with torch.no_grad():
             return self.gpt2(inps)[0][:, :, :50257]
 
-    def _model_generate(self, context, max_length, eos_token_id):
+    def _get_stopping_criteria(self, stopping_criteria_ids):
+        class MultitokenEOSCriteria(transformers.StoppingCriteria):
+            def __init__(self, eos_seq_id: torch.LongTensor, tokenizer):
+                self.eos_seq = tokenizer.decode(eos_seq_id)
+                self.eos_seq_id = eos_seq_id
+                # Decode one extra trailing token so a BPE merge at the
+                # boundary cannot hide the stop sequence.
+                self.eos_seq_len = len(eos_seq_id) + 1
+                self.tokenizer = tokenizer
+
+            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+                # Stop once the decoded stop string appears at the end of
+                # the generated text.
+                last_token_id = input_ids[0, -self.eos_seq_len:]
+                last_tokens = self.tokenizer.decode(last_token_id)
+                is_stopped = self.eos_seq in last_tokens
+                return is_stopped
+
+        class EOSCriteria(transformers.StoppingCriteria):
+            def __init__(self, eos_token_id: int):
+                self.eos_token_id = eos_token_id
+
+            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+                # Stop as soon as the single EOS token is generated.
+                return input_ids[0, -1] == self.eos_token_id
+
+        return transformers.StoppingCriteriaList([
+            MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer),
+            # Compare against the tokenizer's single EOS id; passing the
+            # full stop-sequence ids here would break for multi-token stops.
+            EOSCriteria(self.tokenizer.eos_token_id),
+        ])
+
+    def _model_generate(self, context, max_length, stopping_criteria_ids):
+        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
         return self.gpt2.generate(
-            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
+            context,
+            max_length=max_length,
+            stopping_criteria=stopping_criteria,
+            do_sample=False,
         )
 
     # for backwards compatibility
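For reference, any callable subclassing transformers.StoppingCriteria can plug into generate() this way. Below is a minimal standalone sketch of the same pattern outside lm_eval; it is not part of this commit, and the StopOnSubstring name and its window parameter are illustrative assumptions:

import torch
import transformers

class StopOnSubstring(transformers.StoppingCriteria):
    """Stop generation once `stop_string` appears in the decoded tail."""

    def __init__(self, stop_string: str, tokenizer, window: int = 8):
        self.stop_string = stop_string
        self.tokenizer = tokenizer
        self.window = window  # how many trailing tokens to decode each step

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        tail = self.tokenizer.decode(input_ids[0, -self.window:])
        return self.stop_string in tail

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
context = tokenizer("i like", return_tensors="pt").input_ids
output = model.generate(
    context,
    max_length=20,
    do_sample=False,
    stopping_criteria=transformers.StoppingCriteriaList(
        [StopOnSubstring("say that", tokenizer)]
    ),
)
print(tokenizer.decode(output[0]))  # halts shortly after "say that" appears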
The commit also adds a regression test for the new stopping criteria (new file):
import random

import pytest
import torch

import lm_eval.models as models


@pytest.mark.parametrize(
    "eos_token,test_input,expected",
    [
        ("not", "i like", "i like to say that I'm not"),
        ("say that", "i like", "i like to say that"),
        ("great", "big science is", "big science is a great"),
        (
            "<|endoftext|>",
            "big science has",
            "big science has been done in the past, but it's not the same as the science of the",
        ),
    ],
)
def test_stopping_criteria(eos_token, test_input, expected):
    # Fix seeds for reproducibility; decoding is greedy anyway (do_sample=False).
    random.seed(42)
    torch.random.manual_seed(42)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    gpt2 = models.get_model("gpt2")(device=device)

    context = torch.tensor([gpt2.tokenizer.encode(test_input)])
    stopping_criteria_ids = gpt2.tokenizer.encode(eos_token)

    generations = gpt2._model_generate(
        context,
        max_length=20,
        stopping_criteria_ids=stopping_criteria_ids,
    )
    generations = gpt2.tokenizer.decode(generations[0])
    assert generations == expected
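Note that the expected strings still end with the stop sequence: a StoppingCriteria is only consulted after the offending tokens have been appended, so the stop text remains in the decoded output. A hypothetical post-processing step, not part of the commit, for callers that want it stripped:

decoded = "i like to say that"
stop = "say that"
trimmed = decoded[: decoded.index(stop)].rstrip() if stop in decoded else decoded
print(trimmed)  # "i like to"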