Commit 6c86eb47 authored by Tian Yun

Fixed generation truncation for GPT-2 and T5

parent fce17ee1
@@ -384,7 +384,7 @@ class BaseLM(LM):
                 torch.tensor(primary_until),
             )
 
-            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
+            s = self.tok_decode(cont.tolist())
 
             for term in until:
                 s = s.split(term)[0]
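In BaseLM, the caller no longer slices the prompt off before decoding, because `_model_generate` now returns only the continuation tokens (see the "Remove the context from the generations" change in the GPT-2 hunk below). A minimal sketch, using made-up token ids, of why the old and new decode calls see the same tokens once the slicing moves into `_model_generate`:

import torch

context_enc = torch.tensor([[11, 12, 13]])          # prompt, shape (1, 3)
full_output = torch.tensor([[11, 12, 13, 40, 41]])  # decoder-only generate() echoes the prompt back

# Old behaviour: the caller sliced the prompt off the full output before decoding.
old_cont = full_output[0].tolist()[context_enc.shape[1]:]   # [40, 41]

# New behaviour: _model_generate already returns a 1-D continuation tensor.
cont = full_output[0, context_enc.shape[1]:]
new_cont = cont.tolist()                                    # [40, 41]

assert old_cont == new_cont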
@@ -12,6 +12,7 @@ class HFLM(BaseLM):
         subfolder=None,
         tokenizer=None,
         batch_size=1,
+        parallelize=False
     ):
         super().__init__()
@@ -32,7 +33,7 @@ class HFLM(BaseLM):
         self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
             pretrained,
             revision=revision + ("/" + subfolder if subfolder is not None else ""),
-        ).to(self.device)
+        )
         self.gpt2.eval()
 
         # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
@@ -68,9 +69,11 @@ class HFLM(BaseLM):
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
 
         # TODO: fix multi-gpu
-        # gpus = torch.cuda.device_count()
-        # if gpus > 1:
-        #     self.gpt2 = nn.DataParallel(self.gpt2)
+        if parallelize:
+            self.gpt2.parallelize()
+            self._device = torch.device('cuda:0')
+        else:
+            self.gpt2.to(self._device)
 
     @property
     def eot_token(self):
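The new `parallelize` flag swaps single-device placement for Hugging Face's layer-wise model parallelism: `parallelize()` spreads the transformer blocks over the visible GPUs, while inputs stay on `cuda:0`, where the embeddings sit. This is also why `.to(self.device)` was dropped from the `from_pretrained` call above: device placement is now decided here. A rough usage sketch, assuming a transformers version that still exposes `parallelize()` on GPT-2 and a multi-GPU machine (the checkpoint name is only illustrative):

import torch
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2-xl")
model.eval()

model.parallelize()              # split transformer blocks across available GPUs
device = torch.device("cuda:0")  # inputs must live on the first device

ids = torch.tensor([[464, 3290]], device=device)
with torch.no_grad():
    out = model.generate(ids, max_length=8, do_sample=False)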
@@ -147,15 +150,16 @@ class HFLM(BaseLM):
         ])
 
     def _model_generate(self, context, max_length, stopping_criteria_ids):
-        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.gpt2.generate(
+        # stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        generations = self.gpt2.generate(
             context,
             max_length=max_length,
-            stopping_criteria=stopping_criteria,
+            # stopping_criteria=stopping_criteria,
             do_sample=False,
         )
+        # Remove the context from the generations
+        return generations[0, context.shape[1] :]
 
 
 # for backwards compatibility
 GPT2LM = HFLM
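With the `StoppingCriteria` construction commented out, `generate()` simply runs to `max_length`; stop sequences are still honoured afterwards by the string split shown in the BaseLM hunk at the top (`s.split(term)[0]`). A toy illustration of that post-hoc truncation:

s = "The answer is 42.\n\nQuestion: what comes next?"
until = ["\n\n"]

# greedy decoding produced extra text; cut at the first stop sequence
for term in until:
    s = s.split(term)[0]

print(s)  # -> "The answer is 42."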
@@ -62,7 +62,7 @@ class T5LM(BaseLM):
     @property
     def max_gen_toks(self):
-        return self.tokenizer.model_max_length
+        return 256
 
     @property
     def batch_size(self):
@@ -187,10 +187,10 @@ class T5LM(BaseLM):
         ])
 
     def _model_generate(self, context, max_length, stopping_criteria_ids):
-        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        # stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
         return self.t5.generate(
             context,
             max_length=max_length,
-            stopping_criteria=stopping_criteria,
+            # stopping_criteria=stopping_criteria,
             do_sample=False,
-        )
+        )[0]
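For the encoder-decoder T5 model, `generate()` returns only decoder-side tokens, so the fix just drops the batch dimension with `[0]`; no prompt-length slicing is needed as in the GPT-2 case. A small sketch, using the public t5-small checkpoint purely for illustration:

import transformers

tok = transformers.T5Tokenizer.from_pretrained("t5-small")
model = transformers.T5ForConditionalGeneration.from_pretrained("t5-small")

ids = tok("translate English to German: Hello", return_tensors="pt").input_ids
out = model.generate(ids, max_length=16, do_sample=False)  # shape (1, gen_len)

cont = out[0]  # 1-D tensor of generated tokens; no input tokens to strip
print(tok.decode(cont, skip_special_tokens=True))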