"configs/vscode:/vscode.git/clone" did not exist on "0241806b5e2d43374c530888566c9253bd356751"
Commit e5491709 authored by baberabb's avatar baberabb
Browse files

fix greedy_until

parent 24f4e8d7
...@@ -161,7 +161,9 @@ class VLLM(LM): ...@@ -161,7 +161,9 @@ class VLLM(LM):
# batch tokenize contexts # batch tokenize contexts
context, all_gen_kwargs = zip(*(req.args for req in requests)) context, all_gen_kwargs = zip(*(req.args for req in requests))
context_encoding = self.tokenizer(context) context_encoding = self.tokenizer(context)
requests = list(zip((context, context_encoding.input_ids), all_gen_kwargs)) requests = [
((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
]
def _collate_gen(_requests): def _collate_gen(_requests):
# the negative sign on len(toks) sorts descending - this has a few advantages: # the negative sign on len(toks) sorts descending - this has a few advantages:
...@@ -190,7 +192,7 @@ class VLLM(LM): ...@@ -190,7 +192,7 @@ class VLLM(LM):
) )
for chunk in chunks: for chunk in chunks:
context_and_encoding, all_gen_kwargs = zip(*chunk) context_and_encoding, all_gen_kwargs = zip(*chunk)
context, context_encoding = context_and_encoding context, context_encoding = zip(*context_and_encoding)
# we assume all gen kwargs in the batch are the same # we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it. # this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0] gen_kwargs = all_gen_kwargs[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment