Commit 6c86eb47 authored by Tian Yun

Fixed generation truncation for GPT-2 and T5

parent fce17ee1
@@ -384,7 +384,7 @@ class BaseLM(LM):
                 torch.tensor(primary_until),
             )
-            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
+            s = self.tok_decode(cont.tolist())
             for term in until:
                 s = s.split(term)[0]
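With the stopping-criteria path disabled (see the HFLM and T5LM hunks below), truncation now happens at the string level: `_model_generate` returns only the continuation tokens, and the decoded text is cut at the first occurrence of each stop sequence. A minimal sketch of that post-hoc truncation (the standalone helper name is hypothetical):

    # Post-hoc stop-sequence truncation (hypothetical helper): generation may
    # overrun a stop string, so the decoded text is cut at each terminator.
    def truncate_at_stops(text, until):
        for term in until:
            text = text.split(term)[0]
        return text

    print(truncate_at_stops("42\n\nQ: next question", ["\n\n"]))  # -> "42"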
@@ -12,6 +12,7 @@ class HFLM(BaseLM):
         subfolder=None,
         tokenizer=None,
         batch_size=1,
+        parallelize=False
     ):
         super().__init__()
@@ -32,7 +33,7 @@ class HFLM(BaseLM):
         self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
             pretrained,
             revision=revision + ("/" + subfolder if subfolder is not None else ""),
-        ).to(self.device)
+        )
         self.gpt2.eval()
         # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
@@ -68,9 +69,11 @@ class HFLM(BaseLM):
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
-        # TODO: fix multi-gpu
-        # gpus = torch.cuda.device_count()
-        # if gpus > 1:
-        #     self.gpt2 = nn.DataParallel(self.gpt2)
+        if parallelize:
+            self.gpt2.parallelize()
+            self._device = torch.device('cuda:0')
+        else:
+            self.gpt2.to(self._device)

     @property
     def eot_token(self):
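The new `parallelize` flag also explains why `.to(self.device)` was dropped from `from_pretrained` above: device placement is deferred until after the model is built. `parallelize()` is transformers' naive model-parallelism helper on its GPT-2 and T5 classes (since deprecated in newer releases); it shards the transformer blocks across all visible GPUs, with the embeddings on the first one, so inputs must be sent to cuda:0. A minimal standalone sketch, assuming at least two GPUs are available:

    # Naive model parallelism via transformers' parallelize() (assumes >= 2 GPUs).
    import torch
    import transformers

    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()
    model.parallelize()              # shard transformer blocks across visible GPUs
    device = torch.device("cuda:0")  # the embedding layer stays on the first GPU

    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    ids = tokenizer("Hello", return_tensors="pt").input_ids.to(device)
    out = model.generate(ids, max_length=20, do_sample=False)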
@@ -147,15 +150,16 @@ class HFLM(BaseLM):
         ])

     def _model_generate(self, context, max_length, stopping_criteria_ids):
-        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.gpt2.generate(
+        # stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        generations = self.gpt2.generate(
             context,
             max_length=max_length,
-            stopping_criteria=stopping_criteria,
+            # stopping_criteria=stopping_criteria,
             do_sample=False,
         )
+        # Remove the context from the generations
+        return generations[0, context.shape[1] :]


 # for backwards compatibility
 GPT2LM = HFLM
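This hunk is the core of the truncation fix for GPT-2: decoder-only models echo the prompt, i.e. `generate()` returns the context tokens followed by the continuation, so decoding the full output re-included the prompt. Slicing at `context.shape[1]` keeps only the newly generated tokens. A self-contained sketch of the behavior:

    # Decoder-only generate() returns prompt + continuation; slice off the prompt.
    import torch
    import transformers

    tok = transformers.AutoTokenizer.from_pretrained("gpt2")
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2").eval()

    context = tok("The capital of France is", return_tensors="pt").input_ids
    with torch.no_grad():
        generations = model.generate(context, max_length=context.shape[1] + 8, do_sample=False)
    continuation = generations[0, context.shape[1]:]  # drop the echoed prompt tokens
    print(tok.decode(continuation))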
@@ -62,7 +62,7 @@ class T5LM(BaseLM):
     @property
     def max_gen_toks(self):
-        return self.tokenizer.model_max_length
+        return 256

     @property
     def batch_size(self):
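On the new constant: `tokenizer.model_max_length` reflects the model's input limit rather than a generation budget (and some checkpoints report a huge sentinel value instead), so using it as the number of tokens to generate is wasteful at best. A quick check, assuming the stock t5-base tokenizer:

    # model_max_length is the encoder input limit, not a generation budget.
    from transformers import AutoTokenizer
    print(AutoTokenizer.from_pretrained("t5-base").model_max_length)  # 512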
@@ -187,10 +187,10 @@ class T5LM(BaseLM):
         ])

     def _model_generate(self, context, max_length, stopping_criteria_ids):
-        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        # stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
         return self.t5.generate(
             context,
             max_length=max_length,
-            stopping_criteria=stopping_criteria,
+            # stopping_criteria=stopping_criteria,
             do_sample=False,
-        )
+        )[0]
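Unlike GPT-2, T5 is an encoder-decoder, so `generate()` returns only decoder-side tokens and never echoes the prompt; no context slicing is needed, and the `[0]` simply unwraps the single sequence from the batch dimension. A minimal sketch:

    # Encoder-decoder generate() returns only new (decoder) tokens; [0] unwraps the batch.
    import transformers

    tok = transformers.AutoTokenizer.from_pretrained("t5-small")
    t5 = transformers.T5ForConditionalGeneration.from_pretrained("t5-small").eval()

    context = tok("translate English to German: Hello.", return_tensors="pt").input_ids
    out = t5.generate(context, max_length=32, do_sample=False)[0]
    print(tok.decode(out, skip_special_tokens=True))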