Commit 7da753f2 authored by Tian Yun's avatar Tian Yun
Browse files

Adding null prompt support in T5

parent 20820c3c
@@ -100,6 +100,14 @@ class T5LM(BaseLM):
        inputs, targets = zip(*chunk)
+       # Fill in empty encoder inputs with eos_token
+       inputs = (
+           f"{self.eot_token}"
+           if len(input_) == 0
+           else input_
+           for input_ in inputs
+       )
        inputs_tok = self.tokenizer(
            list(inputs),
            max_length=self.max_length,
@@ -123,7 +131,7 @@ class T5LM(BaseLM):
        for key in targets_tok:
            targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :]
        outputs = self._model_call(inputs_tok, targets_tok)
        log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment