Unverified Commit 4941a8bb authored by Jonathan Tow's avatar Jonathan Tow Committed by GitHub
Browse files

Merge pull request #1 from cjlovering/cjlovering/gen_max_len

Add optional max length to generation.
parents 02ec7889 9384ec91
import abc
from typing import Iterable
from typing import Iterable, Optional
import promptsource
import numpy as np
......@@ -348,17 +348,25 @@ class BaseLM(LM):
for context, until in tqdm(reord.get_reordered()):
if isinstance(until, str):
until = [until]
max_length = None
elif isinstance(until, list) and len(until) == 2:
until, max_length = [until[0]], until[1]
elif isinstance(until, list):
max_length = None
# TODO: Come back to for generation `eos`.
primary_until = self.tok_encode(until[0])
context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
).to(self.device)
if max_length is not None:
max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks)
else:
max_length = context_enc.shape[1] + self.max_gen_toks
cont = self._model_generate(
context_enc,
context_enc.shape[1] + self.max_gen_toks,
max_length,
torch.tensor(primary_until),
)
......@@ -652,7 +660,7 @@ class PromptSourceTask(Task):
super().__init__(data_dir, cache_dir, download_mode)
self.prompt = prompt
def stopping_criteria(self):
def stopping_criteria(self) -> Optional[str]:
"""Denote where the generation should end.
For example, for coqa, this is '\nQ:' and for drop '.'.
......@@ -661,6 +669,10 @@ class PromptSourceTask(Task):
"""
return None
def max_generation_length(self) -> Optional[int]:
"""Denote where the max length of the generation if it is obvious from the task."""
return None
def is_generation_task(self):
return (
"BLEU" in self.prompt.metadata.metrics
......@@ -718,7 +730,9 @@ class PromptSourceTask(Task):
_requests.append(ll_answer_choice)
else:
# TODO(Albert): What is the stop symbol? Is it model specific?
cont_request = rf.greedy_until(ctx, [self.stopping_criteria()])
cont_request = rf.greedy_until(
ctx, [self.stopping_criteria(), self.max_generation_length()]
)
_requests.append(cont_request)
return _requests
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment