Unverified Commit 5b0d95a0 authored by Jonathan Tow's avatar Jonathan Tow Committed by GitHub
Browse files

Merge pull request #2 from cjlovering/cjlovering/request_args

Updated the requests so that it's easier to understand.
parents a3a9a7c2 21d897db
...@@ -345,25 +345,27 @@ class BaseLM(LM): ...@@ -345,25 +345,27 @@ class BaseLM(LM):
reord = utils.Reorderer(requests, _collate) reord = utils.Reorderer(requests, _collate)
for context, until in tqdm(reord.get_reordered()): for context, request_args in tqdm(reord.get_reordered()):
if isinstance(until, str): stopping_criteria = request_args["stopping_criteria"]
until = [until] max_generation_length = request_args["max_generation_length"]
max_length = None
elif isinstance(until, list) and len(until) == 2:
until, max_length = [until[0]], until[1]
elif isinstance(until, list):
max_length = None
assert isinstance(stopping_criteria, str) or stopping_criteria is None
assert (
isinstance(max_generation_length, int) or max_generation_length is None
)
until = [stopping_criteria]
primary_until = self.tok_encode(until[0]) primary_until = self.tok_encode(until[0])
context_enc = torch.tensor( context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]] [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
).to(self.device) ).to(self.device)
if max_length is not None: if max_generation_length is None:
max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks)
else:
max_length = context_enc.shape[1] + self.max_gen_toks max_length = context_enc.shape[1] + self.max_gen_toks
else:
max_length = min(
max_generation_length, context_enc.shape[1] + self.max_gen_toks
)
cont = self._model_generate( cont = self._model_generate(
context_enc, context_enc,
max_length, max_length,
...@@ -720,9 +722,11 @@ class PromptSourceTask(Task): ...@@ -720,9 +722,11 @@ class PromptSourceTask(Task):
else: else:
# If not, then this is a generation prompt. # If not, then this is a generation prompt.
# NOTE: In the future, target will be a list of strings. # NOTE: In the future, target will be a list of strings.
cont_request = rf.greedy_until( request_args = {
ctx, [self.stopping_criteria(), self.max_generation_length()] "stopping_criteria": self.stopping_criteria(),
) "max_generation_length": self.max_generation_length(),
}
cont_request = rf.greedy_until(ctx, request_args)
_requests.append(cont_request) _requests.append(cont_request)
return _requests return _requests
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment