Commit 6ec93da2 authored by jon-tow's avatar jon-tow
Browse files

Add `eos_token` property

parent 34f591af
...@@ -348,7 +348,8 @@ class BaseLM(LM): ...@@ -348,7 +348,8 @@ class BaseLM(LM):
if isinstance(until, str): if isinstance(until, str):
until = [until] until = [until]
(primary_until,) = self.tok_encode(until[0]) # TODO: Come back to for generation `eos`.
primary_until = self.tok_encode(until[0])[0]
context_enc = torch.tensor( context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]] [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
...@@ -616,7 +617,6 @@ class Task(abc.ABC): ...@@ -616,7 +617,6 @@ class Task(abc.ABC):
) )
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot # get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot] fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
...@@ -639,6 +639,9 @@ class PromptSourceTask(Task): ...@@ -639,6 +639,9 @@ class PromptSourceTask(Task):
super().__init__(data_dir, cache_dir, download_mode) super().__init__(data_dir, cache_dir, download_mode)
self.prompt = prompt self.prompt = prompt
def eos_token(self):
raise NotImplementedError()
def doc_to_target(self, doc): def doc_to_target(self, doc):
_, target = self.prompt.apply(doc) _, target = self.prompt.apply(doc)
return f" {target}" return f" {target}"
...@@ -659,7 +662,6 @@ class PromptSourceTask(Task): ...@@ -659,7 +662,6 @@ class PromptSourceTask(Task):
part of the document for `doc`. part of the document for `doc`.
""" """
_requests = [] _requests = []
answer_choices_list = self.prompt.get_answer_choices_list(doc) answer_choices_list = self.prompt.get_answer_choices_list(doc)
if answer_choices_list: if answer_choices_list:
for answer_choice in answer_choices_list: for answer_choice in answer_choices_list:
...@@ -667,8 +669,8 @@ class PromptSourceTask(Task): ...@@ -667,8 +669,8 @@ class PromptSourceTask(Task):
_requests.append(ll_answer_choice) _requests.append(ll_answer_choice)
else: else:
# TODO(Albert): What is the stop symbol? Is it model specific? # TODO(Albert): What is the stop symbol? Is it model specific?
ll_greedy = rf.greedy_until(ctx, ["\nQ:"]) cont_request = rf.greedy_until(ctx, [self.eos_token()])
_requests.append(ll_greedy) _requests.append(cont_request)
return _requests return _requests
...@@ -694,6 +696,7 @@ class PromptSourceTask(Task): ...@@ -694,6 +696,7 @@ class PromptSourceTask(Task):
} }
else: else:
continuation = results continuation = results
raise NotImplementedError()
# Map metric name to HF metric. # Map metric name to HF metric.
# TODO(Albert): What is Other? # TODO(Albert): What is Other?
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment