Commit ecde9d2e authored by Tian Yun

Fixing stopping criteria

parent 1cd4ec01
@@ -121,6 +121,11 @@ class LM(abc.ABC):
 class BaseLM(LM):
+    @property
+    @abstractmethod
+    def eot_token(self):
+        pass
+
     @property
     @abstractmethod
     def eot_token_id(self):
@@ -354,8 +359,15 @@ class BaseLM(LM):
             isinstance(max_generation_length, int) or max_generation_length is None
         )
-        until = [stopping_criteria]
+        if stopping_criteria is None:
+            until = [self.eot_token]
+        else:
+            until = [stopping_criteria]
         primary_until = self.tok_encode(until[0])
+        if len(primary_until) == 0:
+            primary_until = torch.tensor([self.eot_token_id])
         context_enc = torch.tensor(
             [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
         ).to(self.device)
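The hunk above amounts to a small fallback chain for choosing the stop sequence: use the caller's stopping criterion if given, otherwise the model's end-of-text string, and if the stop string encodes to nothing, fall back to the end-of-text token id. A minimal, self-contained sketch of that logic follows; resolve_stopping_criteria, the character-level encoder, and the GPT-2-style values "<|endoftext|>" / 50256 are illustrative assumptions, not part of the commit.

import torch

def resolve_stopping_criteria(stopping_criteria, tok_encode, eot_token, eot_token_id):
    # No explicit stop string supplied: fall back to the end-of-text token.
    if stopping_criteria is None:
        until = [eot_token]
    else:
        until = [stopping_criteria]
    # Encode the primary stop string; if it encodes to nothing, fall back to
    # the end-of-text token id so generation can still terminate.
    primary_until = tok_encode(until[0])
    if len(primary_until) == 0:
        primary_until = torch.tensor([eot_token_id])
    return until, primary_until

# Toy usage with a character-level "tokenizer" and GPT-2-style EOT values.
until, primary_until = resolve_stopping_criteria(
    None, lambda s: [ord(c) for c in s], "<|endoftext|>", 50256
)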
@@ -72,6 +72,10 @@ class HFLM(BaseLM):
         # if gpus > 1:
         #     self.gpt2 = nn.DataParallel(self.gpt2)

+    @property
+    def eot_token(self):
+        return self.tokenizer.eos_token
+
     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
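For context on how the new abstract eot_token property is satisfied, here is a hedged sketch of a HuggingFace-backed implementation; the DummyHF class and its use of AutoTokenizer are assumptions for illustration and do not reproduce HFLM's actual constructor.

from transformers import AutoTokenizer

class DummyHF:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    @property
    def eot_token(self):
        # Stop *string*, e.g. "<|endoftext|>" for GPT-2 tokenizers.
        return self.tokenizer.eos_token

    @property
    def eot_token_id(self):
        # Matching token id, e.g. 50256 for GPT-2.
        return self.tokenizer.eos_token_id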