"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "32c30c3bf9c10189a0e3bae6327a5085a05a5143"
Commit 5e59320b authored by Tian Yun's avatar Tian Yun
Browse files

Modified stopping criteria for T5 and GPT-2

parent c46ff9e4
...@@ -353,11 +353,13 @@ class BaseLM(LM): ...@@ -353,11 +353,13 @@ class BaseLM(LM):
for context, request_args in tqdm(reord.get_reordered()): for context, request_args in tqdm(reord.get_reordered()):
stopping_criteria = request_args["stopping_criteria"] stopping_criteria = request_args["stopping_criteria"]
max_generation_length = request_args["max_generation_length"] max_generation_length = request_args["max_generation_length"]
num_fewshot = request_args["num_fewshot"]
assert isinstance(stopping_criteria, str) or stopping_criteria is None assert isinstance(stopping_criteria, str) or stopping_criteria is None
assert ( assert (
isinstance(max_generation_length, int) or max_generation_length is None isinstance(max_generation_length, int) or max_generation_length is None
) )
assert isinstance(num_fewshot, int) or num_fewshot is None
if stopping_criteria is None: if stopping_criteria is None:
until = [self.eot_token] until = [self.eot_token]
...@@ -382,6 +384,7 @@ class BaseLM(LM): ...@@ -382,6 +384,7 @@ class BaseLM(LM):
context_enc, context_enc,
max_length, max_length,
torch.tensor(primary_until), torch.tensor(primary_until),
num_fewshot,
) )
s = self.tok_decode(cont.tolist()) s = self.tok_decode(cont.tolist())
...@@ -536,7 +539,7 @@ class Task(abc.ABC): ...@@ -536,7 +539,7 @@ class Task(abc.ABC):
pass pass
@abstractmethod @abstractmethod
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx, args):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
...@@ -546,6 +549,8 @@ class Task(abc.ABC): ...@@ -546,6 +549,8 @@ class Task(abc.ABC):
The context string, generated by fewshot_context. This includes the natural The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
:param args: dict
The specifics of the context, including number of few shots.
""" """
pass pass
...@@ -724,7 +729,7 @@ class PromptSourceTask(Task): ...@@ -724,7 +729,7 @@ class PromptSourceTask(Task):
text, _ = self.prompt.apply(doc) text, _ = self.prompt.apply(doc)
return text return text
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx, args):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
...@@ -734,6 +739,8 @@ class PromptSourceTask(Task): ...@@ -734,6 +739,8 @@ class PromptSourceTask(Task):
The context string, generated by fewshot_context. This includes the natural The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
:param args: dict
The specifics of the context, including number of few shots.
""" """
_requests = [] _requests = []
answer_choices_list = self.prompt.get_answer_choices_list(doc) answer_choices_list = self.prompt.get_answer_choices_list(doc)
...@@ -749,6 +756,7 @@ class PromptSourceTask(Task): ...@@ -749,6 +756,7 @@ class PromptSourceTask(Task):
request_args = { request_args = {
"stopping_criteria": self.stopping_criteria(), "stopping_criteria": self.stopping_criteria(),
"max_generation_length": self.max_generation_length(), "max_generation_length": self.max_generation_length(),
"num_fewshot": args["num_fewshot"],
} }
cont_request = rf.greedy_until(ctx, request_args) cont_request = rf.greedy_until(ctx, request_args)
_requests.append(cont_request) _requests.append(cont_request)
......
...@@ -206,7 +206,8 @@ def evaluate( ...@@ -206,7 +206,8 @@ def evaluate(
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
) )
fewshotex_logging_info["doc_id"] = original_doc_id fewshotex_logging_info["doc_id"] = original_doc_id
reqs = task.construct_requests(doc, ctx) args = {"num_fewshot": num_fewshot}
reqs = task.construct_requests(doc, ctx, args)
if not isinstance(reqs, (list, tuple)): if not isinstance(reqs, (list, tuple)):
reqs = [reqs] reqs = [reqs]
for i, req in enumerate(reqs): for i, req in enumerate(reqs):
......
...@@ -149,14 +149,23 @@ class HFLM(BaseLM): ...@@ -149,14 +149,23 @@ class HFLM(BaseLM):
EOSCriteria(self.tokenizer.eos_token) EOSCriteria(self.tokenizer.eos_token)
]) ])
def _model_generate(self, context, max_length, stopping_criteria_ids): def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
# stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
generations = self.gpt2.generate(
context, if num_fewshot == 0:
max_length=max_length, generations = self.gpt2.generate(
# stopping_criteria=stopping_criteria, context,
do_sample=False, max_length=max_length,
) eos_token_id=self.eot_token_id,
do_sample=False,
)
else:
generations = self.gpt2.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
# Remove the context from the generations # Remove the context from the generations
return generations[0, context.shape[1] :] return generations[0, context.shape[1] :]
......
...@@ -186,11 +186,21 @@ class T5LM(BaseLM): ...@@ -186,11 +186,21 @@ class T5LM(BaseLM):
EOSCriteria(self.tokenizer.eos_token) EOSCriteria(self.tokenizer.eos_token)
]) ])
def _model_generate(self, context, max_length, stopping_criteria_ids): def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
# stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
return self.t5.generate(
context, if num_fewshot == 0:
max_length=max_length, generations = self.t5.generate(
# stopping_criteria=stopping_criteria, context,
do_sample=False, max_length=max_length,
)[0] eos_token_id=self.eot_token_id,
do_sample=False,
)
else:
generations = self.t5.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
return generations[0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment