Commit 29bff88d authored by Tian Yun

Updated T0 and GPT-J

parent b62d1bec
@@ -8,6 +8,7 @@ class GPTJLM(BaseLM):
         self,
         device="cuda",
         batch_size=1,
+        parallelize=False,
     ):
         super().__init__()
@@ -35,9 +36,11 @@ class GPTJLM(BaseLM):
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
         # TODO: fix multi-gpu
         # gpus = torch.cuda.device_count()
         # if gpus > 1:
         #     self.gptj = nn.DataParallel(self.gptj)
-        self.gptj.to(self._device)
+        if parallelize:
+            self.gptj.parallelize()
+            self._device = torch.device('cuda:0')
+        else:
+            self.gptj.to(self._device)

     @property
     def eot_token(self):
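Note on the parallelize path: `parallelize()` is presumably the (since-deprecated) experimental model-parallelism API in Hugging Face Transformers, which shards the model's layers across all visible GPUs while the embeddings stay on the first device; that would be why `self._device` is pinned to `cuda:0`, where inputs must be placed. A minimal usage sketch under that assumption (constructor arguments as in this diff; any others are omitted):

    # Sketch only: shard GPT-J over all visible GPUs via the HF parallelize API.
    lm = GPTJLM(device="cuda", batch_size=1, parallelize=True)
    # Input tensors must then live on cuda:0, the device holding the first shard.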
@@ -113,11 +116,23 @@ class GPTJLM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.gptj.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+        if num_fewshot == 0:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        # Remove the context from the generations
+        return generations[0, context.shape[1] :]
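Two things changed above: zero-shot generation now terminates on the model's own EOS token (via `eos_token_id`) rather than custom stopping criteria, and the prompt is stripped from the returned sequence, since decoder-only `generate()` returns the prompt and continuation concatenated. A self-contained sketch of that final slice (tensor values are purely illustrative):

    import torch

    context = torch.tensor([[10, 11, 12]])            # (1, 3) prompt ids
    generations = torch.tensor([[10, 11, 12, 7, 8]])  # (1, 5) prompt + new tokens
    continuation = generations[0, context.shape[1]:]  # tensor([7, 8])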
@@ -56,7 +56,7 @@ class T0LM(BaseLM):
     @property
     def max_gen_toks(self):
-        return self.tokenizer.model_max_length
+        return 256

     @property
     def batch_size(self):
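The hard-coded 256 bounds the generation budget directly; `tokenizer.model_max_length` reflects the maximum input length (512 for T5-family tokenizers, or a huge sentinel value when unset), so it was likely never a sensible number of tokens to generate.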
@@ -94,6 +94,14 @@ class T0LM(BaseLM):
             inputs, targets = zip(*chunk)

+            # Fill in empty encoder inputs with eos_token
+            inputs = (
+                f"{self.eot_token}"
+                if len(input_) == 0
+                else input_
+                for input_ in inputs
+            )
+
             inputs_tok = self.tokenizer(
                 list(inputs),
                 max_length=self.max_length,
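The generator expression guards the tokenizer against empty encoder inputs (e.g. a zero-shot prompt with no context) by substituting the end-of-text token. A standalone sketch of the same idiom, with "</s>" standing in for `self.eot_token`:

    # Sketch: replace empty strings before tokenization, as in the diff above.
    inputs = ["translate: hello", "", "summarize: ..."]
    inputs = ("</s>" if len(input_) == 0 else input_ for input_ in inputs)
    print(list(inputs))  # ['translate: hello', '</s>', 'summarize: ...']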
@@ -172,11 +180,21 @@ class T0LM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.t0.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+        if num_fewshot == 0:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        return generations[0]
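Unlike the GPT-J version, no context slicing is needed here: T0 is an encoder-decoder model, so `generate()` returns only decoder-side tokens. For the few-shot branch, the repo's `EOSCriteria` implementation is not shown in this diff; a hypothetical criterion in the same spirit, built on the Transformers `StoppingCriteria` interface, might look like:

    import torch
    from transformers import StoppingCriteria

    class StopOnTokenId(StoppingCriteria):
        # Hypothetical example, not the repo's EOSCriteria: stop decoding
        # as soon as the most recently generated token equals stop_id.
        def __init__(self, stop_id: int):
            self.stop_id = stop_id

        def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
            return input_ids[0, -1].item() == self.stop_id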