Commit 5e59320b authored by Tian Yun's avatar Tian Yun
Browse files

Modified stopping criteria for T5 and GPT-2

parent c46ff9e4
......@@ -353,11 +353,13 @@ class BaseLM(LM):
for context, request_args in tqdm(reord.get_reordered()):
stopping_criteria = request_args["stopping_criteria"]
max_generation_length = request_args["max_generation_length"]
num_fewshot = request_args["num_fewshot"]
assert isinstance(stopping_criteria, str) or stopping_criteria is None
assert (
isinstance(max_generation_length, int) or max_generation_length is None
)
assert isinstance(num_fewshot, int) or num_fewshot is None
if stopping_criteria is None:
until = [self.eot_token]
......@@ -382,6 +384,7 @@ class BaseLM(LM):
context_enc,
max_length,
torch.tensor(primary_until),
num_fewshot,
)
s = self.tok_decode(cont.tolist())
......@@ -536,7 +539,7 @@ class Task(abc.ABC):
pass
@abstractmethod
def construct_requests(self, doc, ctx):
def construct_requests(self, doc, ctx, args):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
......@@ -546,6 +549,8 @@ class Task(abc.ABC):
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
:param args: dict
The specifics of the context, including number of few shots.
"""
pass
......@@ -724,7 +729,7 @@ class PromptSourceTask(Task):
text, _ = self.prompt.apply(doc)
return text
def construct_requests(self, doc, ctx):
def construct_requests(self, doc, ctx, args):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
......@@ -734,6 +739,8 @@ class PromptSourceTask(Task):
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
:param args: dict
The specifics of the context, including number of few shots.
"""
_requests = []
answer_choices_list = self.prompt.get_answer_choices_list(doc)
......@@ -749,6 +756,7 @@ class PromptSourceTask(Task):
request_args = {
"stopping_criteria": self.stopping_criteria(),
"max_generation_length": self.max_generation_length(),
"num_fewshot": args["num_fewshot"],
}
cont_request = rf.greedy_until(ctx, request_args)
_requests.append(cont_request)
......
......@@ -206,7 +206,8 @@ def evaluate(
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
fewshotex_logging_info["doc_id"] = original_doc_id
reqs = task.construct_requests(doc, ctx)
args = {"num_fewshot": num_fewshot}
reqs = task.construct_requests(doc, ctx, args)
if not isinstance(reqs, (list, tuple)):
reqs = [reqs]
for i, req in enumerate(reqs):
......
......@@ -149,14 +149,23 @@ class HFLM(BaseLM):
EOSCriteria(self.tokenizer.eos_token)
])
def _model_generate(self, context, max_length, stopping_criteria_ids):
# stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
generations = self.gpt2.generate(
context,
max_length=max_length,
# stopping_criteria=stopping_criteria,
do_sample=False,
)
def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
    """Greedily generate a continuation for `context` with GPT-2.

    :param context: torch.Tensor
        Batch of input token ids (assumed shape (1, seq_len) — the return
        indexing only uses row 0; TODO confirm against callers).
    :param max_length: int
        Maximum total length passed through to `generate`.
    :param stopping_criteria_ids: torch.Tensor
        Token ids of the task's stop sequence, used to build custom
        stopping criteria for the few-shot case.
    :param num_fewshot: int
        Number of few-shot examples in the context; selects the stopping
        strategy (0 => stop on EOS, otherwise => custom stop sequence).
    :return: torch.Tensor
        The generated continuation with the context tokens stripped off.
    """
    # Kwargs shared by both branches; greedy decoding in all cases.
    generate_kwargs = {
        "max_length": max_length,
        "do_sample": False,
    }
    if num_fewshot == 0:
        # Zero-shot: no few-shot separator to match, so stopping on the
        # model's EOS token is sufficient.
        generate_kwargs["eos_token_id"] = self.eot_token_id
    else:
        # Few-shot: stop when the task's stop sequence is produced. Only
        # build the (per-call) stopping-criteria object when it is
        # actually used — the zero-shot branch ignores it.
        generate_kwargs["stopping_criteria"] = self._get_stopping_criteria(
            stopping_criteria_ids
        )
    generations = self.gpt2.generate(context, **generate_kwargs)
    # Remove the context from the generations
    return generations[0, context.shape[1] :]
......
......@@ -186,11 +186,21 @@ class T5LM(BaseLM):
EOSCriteria(self.tokenizer.eos_token)
])
def _model_generate(self, context, max_length, stopping_criteria_ids):
# stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
return self.t5.generate(
context,
max_length=max_length,
# stopping_criteria=stopping_criteria,
do_sample=False,
)[0]
def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
    """Greedily generate a continuation for `context` with T5.

    :param context: torch.Tensor
        Batch of input token ids for the encoder (row 0 of the output is
        returned, so a batch of size 1 is assumed — TODO confirm).
    :param max_length: int
        Maximum generation length passed through to `generate`.
    :param stopping_criteria_ids: torch.Tensor
        Token ids of the task's stop sequence, used to build custom
        stopping criteria for the few-shot case.
    :param num_fewshot: int
        Number of few-shot examples in the context; selects the stopping
        strategy (0 => stop on EOS, otherwise => custom stop sequence).
    :return: torch.Tensor
        The generated sequence. NOTE(review): unlike the GPT-2 counterpart
        the context is not stripped here — presumably because T5's decoder
        output does not echo the encoder input; verify against callers.
    """
    # Kwargs shared by both branches; greedy decoding in all cases.
    generate_kwargs = {
        "max_length": max_length,
        "do_sample": False,
    }
    if num_fewshot == 0:
        # Zero-shot: stopping on the model's EOS token is sufficient.
        generate_kwargs["eos_token_id"] = self.eot_token_id
    else:
        # Few-shot: stop when the task's stop sequence is produced. Only
        # build the stopping-criteria object when it is actually used —
        # the zero-shot branch ignores it.
        generate_kwargs["stopping_criteria"] = self._get_stopping_criteria(
            stopping_criteria_ids
        )
    generations = self.t5.generate(context, **generate_kwargs)
    return generations[0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment