Commit 21d897db authored by cjlovering's avatar cjlovering
Browse files

Updated the requests so that its easier to understand.

parent 4f85bcf9
......@@ -345,25 +345,27 @@ class BaseLM(LM):
reord = utils.Reorderer(requests, _collate)
for context, until in tqdm(reord.get_reordered()):
if isinstance(until, str):
until = [until]
max_length = None
elif isinstance(until, list) and len(until) == 2:
until, max_length = [until[0]], until[1]
elif isinstance(until, list):
max_length = None
for context, request_args in tqdm(reord.get_reordered()):
stopping_criteria = request_args["stopping_criteria"]
max_generation_length = request_args["max_generation_length"]
assert isinstance(stopping_criteria, str) or stopping_criteria is None
assert (
isinstance(max_generation_length, int) or max_generation_length is None
)
until = [stopping_criteria]
primary_until = self.tok_encode(until[0])
context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
).to(self.device)
if max_length is not None:
max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks)
else:
if max_generation_length is None:
max_length = context_enc.shape[1] + self.max_gen_toks
else:
max_length = min(
max_generation_length, context_enc.shape[1] + self.max_gen_toks
)
cont = self._model_generate(
context_enc,
max_length,
......@@ -720,9 +722,11 @@ class PromptSourceTask(Task):
else:
# If not, then this is a generation prompt.
# NOTE: In the future, target will be a list of strings.
cont_request = rf.greedy_until(
ctx, [self.stopping_criteria(), self.max_generation_length()]
)
request_args = {
"stopping_criteria": self.stopping_criteria(),
"max_generation_length": self.max_generation_length(),
}
cont_request = rf.greedy_until(ctx, request_args)
_requests.append(cont_request)
return _requests
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment