Commit 9384ec91 authored by cjlovering's avatar cjlovering
Browse files

A dependency required this but it was not installed by default

parent 02ec7889
import abc import abc
from typing import Iterable from typing import Iterable, Optional
import promptsource import promptsource
import numpy as np import numpy as np
...@@ -348,17 +348,25 @@ class BaseLM(LM): ...@@ -348,17 +348,25 @@ class BaseLM(LM):
for context, until in tqdm(reord.get_reordered()): for context, until in tqdm(reord.get_reordered()):
if isinstance(until, str): if isinstance(until, str):
until = [until] until = [until]
max_length = None
elif isinstance(until, list) and len(until) == 2:
until, max_length = [until[0]], until[1]
elif isinstance(until, list):
max_length = None
# TODO: Come back to for generation `eos`.
primary_until = self.tok_encode(until[0]) primary_until = self.tok_encode(until[0])
context_enc = torch.tensor( context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]] [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
).to(self.device) ).to(self.device)
if max_length is not None:
max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks)
else:
max_length = context_enc.shape[1] + self.max_gen_toks
cont = self._model_generate( cont = self._model_generate(
context_enc, context_enc,
context_enc.shape[1] + self.max_gen_toks, max_length,
torch.tensor(primary_until), torch.tensor(primary_until),
) )
...@@ -652,7 +660,7 @@ class PromptSourceTask(Task): ...@@ -652,7 +660,7 @@ class PromptSourceTask(Task):
super().__init__(data_dir, cache_dir, download_mode) super().__init__(data_dir, cache_dir, download_mode)
self.prompt = prompt self.prompt = prompt
def stopping_criteria(self): def stopping_criteria(self) -> Optional[str]:
"""Denote where the generation should end. """Denote where the generation should end.
For example, for coqa, this is '\nQ:' and for drop '.'. For example, for coqa, this is '\nQ:' and for drop '.'.
...@@ -661,6 +669,10 @@ class PromptSourceTask(Task): ...@@ -661,6 +669,10 @@ class PromptSourceTask(Task):
""" """
return None return None
def max_generation_length(self) -> Optional[int]:
"""Denote where the max length of the generation if it is obvious from the task."""
return None
def is_generation_task(self): def is_generation_task(self):
return ( return (
"BLEU" in self.prompt.metadata.metrics "BLEU" in self.prompt.metadata.metrics
...@@ -718,7 +730,9 @@ class PromptSourceTask(Task): ...@@ -718,7 +730,9 @@ class PromptSourceTask(Task):
_requests.append(ll_answer_choice) _requests.append(ll_answer_choice)
else: else:
# TODO(Albert): What is the stop symbol? Is it model specific? # TODO(Albert): What is the stop symbol? Is it model specific?
cont_request = rf.greedy_until(ctx, [self.stopping_criteria()]) cont_request = rf.greedy_until(
ctx, [self.stopping_criteria(), self.max_generation_length()]
)
_requests.append(cont_request) _requests.append(cont_request)
return _requests return _requests
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment