"include/vscode:/vscode.git/clone" did not exist on "bec0399aa9da409e16562574725f16a7de732791"
Commit 29bff88d authored by Tian Yun

Updated T0 and GPT-J

parent b62d1bec
@@ -8,6 +8,7 @@ class GPTJLM(BaseLM):
         self,
         device="cuda",
         batch_size=1,
+        parallelize=False,
     ):
         super().__init__()
@@ -35,9 +36,11 @@ class GPTJLM(BaseLM):
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

         # TODO: fix multi-gpu
-        # gpus = torch.cuda.device_count()
-        # if gpus > 1:
-        #     self.gptj = nn.DataParallel(self.gptj)
+        if parallelize:
+            self.gptj.parallelize()
+            self._device = torch.device('cuda:0')
+        else:
+            self.gptj.to(self._device)

     @property
     def eot_token(self):
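A minimal standalone sketch of what the new parallelize flag drives, assuming the Hugging Face GPT-J checkpoint name (the diff does not show how self.gptj is loaded). parallelize() shards the transformer blocks across all visible GPUs, which is why _device is pinned to cuda:0 so inputs land on the first shard:

import torch
from transformers import GPTJForCausalLM

model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")  # assumed checkpoint
model.parallelize()               # shard layers across all visible GPUs
device = torch.device("cuda:0")   # inputs must live on the first shard
input_ids = torch.tensor([[1, 2, 3]]).to(device)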
@@ -113,11 +116,23 @@ class GPTJLM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.gptj.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+
+        if num_fewshot == 0:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        # Remove the context from the generations
+        return generations[0, context.shape[1] :]
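The EOSCriteria used above is defined elsewhere in the repo. A plausible sketch of such a class, assuming it subclasses Hugging Face's StoppingCriteria; the tokenizer argument is a hypothetical addition here, since the actual constructor shown in the diff takes only the EOS string:

import torch
from transformers import StoppingCriteria

class EOSCriteria(StoppingCriteria):
    # Hypothetical implementation: stop once the newest token decodes to eos_token.
    def __init__(self, eos_token, tokenizer):
        self.eos_token = eos_token
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return self.tokenizer.decode(input_ids[0, -1]) == self.eos_token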
@@ -56,7 +56,7 @@ class T0LM(BaseLM):
     @property
     def max_gen_toks(self):
-        return self.tokenizer.model_max_length
+        return 256

     @property
     def batch_size(self):
@@ -94,6 +94,14 @@ class T0LM(BaseLM):
             inputs, targets = zip(*chunk)

+            # Fill in empty encoder inputs with eos_token
+            inputs = (
+                f"{self.eot_token}"
+                if len(input_) == 0
+                else input_
+                for input_ in inputs
+            )
+
             inputs_tok = self.tokenizer(
                 list(inputs),
                 max_length=self.max_length,
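The generator expression above lazily swaps empty prompts for the EOS string before tokenization, since a zero-length encoder input would otherwise produce an empty tensor. A standalone illustration with assumed values:

eot_token = "</s>"  # assumed value of self.eot_token for T0's tokenizer
inputs = ["Translate to French: hello", "", "Summarize: ..."]
inputs = (eot_token if len(input_) == 0 else input_ for input_ in inputs)
print(list(inputs))  # ['Translate to French: hello', '</s>', 'Summarize: ...']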
@@ -172,11 +180,21 @@ class T0LM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.t0.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+
+        if num_fewshot == 0:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        return generations[0]
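Unlike the GPT-J version, this return does not slice the context off the output: T0 is an encoder-decoder model, so generate() emits only decoder tokens. A quick sketch of that difference, with the checkpoint name assumed:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tok = AutoTokenizer.from_pretrained("bigscience/T0_3B")    # assumed checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0_3B")

context = tok("Is this review positive? Review: loved it!", return_tensors="pt").input_ids
out = model.generate(context, max_length=10, do_sample=False)
# out[0] contains only newly generated tokens; no context slicing is needed.
print(tok.decode(out[0], skip_special_tokens=True))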