Commit a2920e3d authored by baberabb

add vllm model

parent 3d8bc82f
@@ -74,7 +74,7 @@ class VLLM(LM):
     def _model_generate(
         self,
-        requests: List = None,
+        requests: List[int] = None,
         generate: bool = False,
         max_tokens: int = None,
         stop: Optional[List[str]] = None,
@@ -153,7 +153,7 @@ class VLLM(LM):
         # batch tokenize contexts
         context, all_gen_kwargs = zip(*(req.args for req in requests))
         context_encoding = self.tokenizer(context)
-        requests = zip((context, context_encoding.input_ids), all_gen_kwargs)
+        requests = list(zip((context, context_encoding.input_ids), all_gen_kwargs))

         def _collate_gen(_requests):
             # the negative sign on len(toks) sorts descending - this has a few advantages:
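Wrapping `zip(...)` in `list(...)` matters because `zip` returns a one-shot iterator in Python 3: once it has been iterated (for example by the collation step defined below), a bare `zip` object is empty on any later pass. A minimal standalone illustration, not taken from the commit:

```python
# zip() yields a lazy, single-use iterator; list() materializes it so the
# (context, token_ids) / gen_kwargs pairs can be sorted and re-read later.
pairs = zip(("ctx a", "ctx b"), ({"max_gen_toks": 32}, {"max_gen_toks": 64}))
print(list(pairs))  # [('ctx a', {'max_gen_toks': 32}), ('ctx b', {'max_gen_toks': 64})]
print(list(pairs))  # [] -- the iterator is already exhausted
```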
@@ -269,7 +269,7 @@ class VLLM(LM):
             inps.append(inp)
             ctxlens.append(ctxlen)
-        outputs = self._model_generate(generate=False)
+        outputs = self._model_generate(requests=inps, generate=False)
         for output, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
             outputs, ctxlens, chunk
...
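Taken together, the hunks make the loglikelihood path pass its tokenized prompts to `_model_generate` explicitly (`requests=inps`). A minimal sketch of how such a wrapper could dispatch to vLLM, assuming the `SamplingParams` / `LLM.generate` interface of the vLLM releases current at the time of this commit and a hypothetical `model` argument standing in for the wrapped engine (not the commit's actual implementation):

```python
from typing import List, Optional
from vllm import LLM, SamplingParams

def model_generate(
    model: LLM,                        # hypothetical: the wrapped vLLM engine
    requests: List[List[int]] = None,  # token-id lists, one per prompt (typed List[int] in the commit)
    generate: bool = False,
    max_tokens: int = None,
    stop: Optional[List[str]] = None,
):
    if generate:
        # free-form generation for generate_until-style tasks
        sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop)
    else:
        # scoring path: greedy, request prompt logprobs only (used for loglikelihood)
        sampling_params = SamplingParams(temperature=0, prompt_logprobs=1, max_tokens=1)
    return model.generate(
        prompt_token_ids=requests, sampling_params=sampling_params
    )
```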