Unverified Commit 4941a8bb authored by Jonathan Tow, committed by GitHub

Merge pull request #1 from cjlovering/cjlovering/gen_max_len

Add optional max length to generation.
parents 02ec7889 9384ec91
 import abc
-from typing import Iterable
+from typing import Iterable, Optional
 
 import promptsource
 import numpy as np
@@ -348,17 +348,25 @@ class BaseLM(LM):
         for context, until in tqdm(reord.get_reordered()):
             if isinstance(until, str):
                 until = [until]
+                max_length = None
+            elif isinstance(until, list) and len(until) == 2:
+                until, max_length = [until[0]], until[1]
+            elif isinstance(until, list):
+                max_length = None
 
             # TODO: Come back to for generation `eos`.
             primary_until = self.tok_encode(until[0])
 
             context_enc = torch.tensor(
                 [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
             ).to(self.device)
 
+            if max_length is not None:
+                max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks)
+            else:
+                max_length = context_enc.shape[1] + self.max_gen_toks
+
             cont = self._model_generate(
                 context_enc,
-                context_enc.shape[1] + self.max_gen_toks,
+                max_length,
                 torch.tensor(primary_until),
             )
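A task-supplied cap can only tighten the default budget of `context_enc.shape[1] + self.max_gen_toks`; it can never extend it. A minimal standalone sketch of that arithmetic (names like `effective_max_length`, `context_len`, and `requested` are illustrative, not from the diff):

from typing import Optional

def effective_max_length(
    context_len: int, max_gen_toks: int, requested: Optional[int]
) -> int:
    # Default budget: the full context plus up to `max_gen_toks` new tokens.
    budget = context_len + max_gen_toks
    # A task-requested cap may tighten the budget but never exceed it.
    return budget if requested is None else min(requested, budget)

assert effective_max_length(100, 32, None) == 132  # no cap: default budget
assert effective_max_length(100, 32, 120) == 120   # cap below budget wins
assert effective_max_length(100, 32, 500) == 132   # oversized cap is clamped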
@@ -652,7 +660,7 @@ class PromptSourceTask(Task):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt
 
-    def stopping_criteria(self):
+    def stopping_criteria(self) -> Optional[str]:
         """Denote where the generation should end.
 
         For example, for coqa, this is '\nQ:' and for drop '.'.
@@ -661,6 +669,10 @@ class PromptSourceTask(Task):
         """
         return None
 
+    def max_generation_length(self) -> Optional[int]:
+        """Denote the max length of the generation if it is obvious from the task."""
+        return None
+
     def is_generation_task(self):
         return (
             "BLEU" in self.prompt.metadata.metrics
@@ -718,7 +730,9 @@ class PromptSourceTask(Task):
                 _requests.append(ll_answer_choice)
             else:
                 # TODO(Albert): What is the stop symbol? Is it model specific?
-                cont_request = rf.greedy_until(ctx, [self.stopping_criteria()])
+                cont_request = rf.greedy_until(
+                    ctx, [self.stopping_criteria(), self.max_generation_length()]
+                )
                 _requests.append(cont_request)
 
         return _requests
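The two-element list built in the request above is exactly what the `BaseLM.greedy_until` hunk unpacks on the model side. A minimal sketch of that round trip, with an illustrative helper name (`parse_until`) that does not appear in the diff:

from typing import List, Optional, Tuple, Union

def parse_until(until: Union[str, List]) -> Tuple[List[str], Optional[int]]:
    """Normalize an `until` spec into ([stop_string], optional max length)."""
    if isinstance(until, str):
        return [until], None
    if isinstance(until, list) and len(until) == 2:
        stop, max_len = until
        return [stop], max_len
    return [until[0]], None

# What the task side produces and the model side consumes:
print(parse_until(["\nQ:", 64]))    # (['\nQ:'], 64)   -- task with a cap
print(parse_until(["\nQ:", None]))  # (['\nQ:'], None) -- no cap advertised
print(parse_until("\nQ:"))          # (['\nQ:'], None) -- legacy string form

Because `max_generation_length()` defaults to None, tasks without a cap still emit a two-element list, and the model side simply falls back to its default budget.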