Commit a2920e3d authored by baberabb
Browse files

add vllm model

parent 3d8bc82f
......@@ -74,7 +74,7 @@ class VLLM(LM):
def _model_generate(
self,
-requests: List = None,
+requests: List[int] = None,
generate: bool = False,
max_tokens: int = None,
stop: Optional[List[str]] = None,
......@@ -153,7 +153,7 @@ class VLLM(LM):
# batch tokenize contexts
context, all_gen_kwargs = zip(*(req.args for req in requests))
context_encoding = self.tokenizer(context)
-requests = zip((context, context_encoding.input_ids), all_gen_kwargs)
+requests = list(zip((context, context_encoding.input_ids), all_gen_kwargs))
def _collate_gen(_requests):
# the negative sign on len(toks) sorts descending - this has a few advantages:
......@@ -269,7 +269,7 @@ class VLLM(LM):
inps.append(inp)
ctxlens.append(ctxlen)
-outputs = self._model_generate(generate=False)
+outputs = self._model_generate(requests=inps, generate=False)
for output, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
outputs, ctxlens, chunk
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment