Commit 6ec93da2 authored by jon-tow
Browse files

Add `eos_token` property

parent 34f591af
......@@ -348,7 +348,8 @@ class BaseLM(LM):
if isinstance(until, str):
until = [until]
(primary_until,) = self.tok_encode(until[0])
# TODO: Come back to this for handling the generation `eos` token.
primary_until = self.tok_encode(until[0])[0]
context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
......@@ -616,7 +617,6 @@ class Task(abc.ABC):
)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# Remove the doc we're evaluating from the few-shot examples, if it was sampled.
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
......@@ -639,6 +639,9 @@ class PromptSourceTask(Task):
super().__init__(data_dir, cache_dir, download_mode)
self.prompt = prompt
def eos_token(self):
    """Return the end-of-sequence marker used to stop free-form generation.

    The value is passed as the only stop sequence to `rf.greedy_until`
    when a prompt has no answer choices, so it is presumably a string
    (TODO confirm against the `greedy_until` contract).

    Raises:
        NotImplementedError: always; this implementation returns nothing
            and is presumably meant to be overridden (TODO confirm).
    """
    raise NotImplementedError()
def doc_to_target(self, doc):
    """Return the target text for `doc`, prefixed with a single space.

    Args:
        doc: The document handed to `self.prompt.apply`.

    Returns:
        The second element of the pair returned by `self.prompt.apply(doc)`,
        formatted into a string with one leading space.
    """
    # `apply` must yield exactly two items; only the second (the target) is used.
    _, target = self.prompt.apply(doc)
    return f" {target}"
......@@ -659,7 +662,6 @@ class PromptSourceTask(Task):
part of the document for `doc`.
"""
_requests = []
answer_choices_list = self.prompt.get_answer_choices_list(doc)
if answer_choices_list:
for answer_choice in answer_choices_list:
......@@ -667,8 +669,8 @@ class PromptSourceTask(Task):
_requests.append(ll_answer_choice)
else:
# TODO(Albert): What is the stop symbol? Is it model specific?
ll_greedy = rf.greedy_until(ctx, ["\nQ:"])
_requests.append(ll_greedy)
cont_request = rf.greedy_until(ctx, [self.eos_token()])
_requests.append(cont_request)
return _requests
......@@ -694,6 +696,7 @@ class PromptSourceTask(Task):
}
else:
continuation = results
raise NotImplementedError()
# Map metric name to HF metric.
# TODO(Albert): What is Other?
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment