Commit 37f10cad authored by baberabb

add auto batching

parent 2c20df08
 from collections import defaultdict
-from typing import List, Tuple, Optional, Literal
+from typing import List, Tuple, Optional, Literal, Union
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
@@ -27,7 +27,7 @@ class VLLM(LM):
         quantization: Optional[Literal["awq"]] = None,
         max_gen_toks: int = 256,
         swap_space: int = 4,
-        batch_size: int = 1,
+        batch_size: Union[str, int] = 1,
         max_batch_size=None,
         max_length: int = None,
         seed: int = 1234,
@@ -206,7 +206,7 @@ class VLLM(LM):
         for key, re_ord in re_ords.items():
             chunks = utils.chunks(
                 re_ord.get_reordered(),
-                n=self.batch_size,
+                n=self.batch_size if self.batch_size != "auto" else 0,
                 fn=None,
             )
             for chunk in chunks:
@@ -285,7 +285,7 @@ class VLLM(LM):
         chunks = utils.chunks(
             re_ord.get_reordered(),
-            n=self.batch_size,
+            n=self.batch_size if self.batch_size != "auto" else 0,
             fn=None,
         )
         pbar = tqdm(total=len(requests), disable=disable_tqdm)
...
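
How the "auto" value takes effect: the diff routes batch_size="auto" into utils.chunks as n=0, which disables fixed-size chunking, so the entire reordered request list reaches vLLM in a single call and vLLM's own continuous-batching scheduler decides the effective batch sizes. Below is a minimal sketch of the assumed semantics of lm_eval.utils.chunks (simplified, not the verbatim implementation) showing why n=0 yields one chunk:

def chunks(iterable, n=0, fn=None):
    # Accumulate items and cut a chunk whenever the buffer reaches the
    # target size; fn, when given, computes a per-position size instead.
    arr = []
    for i, x in enumerate(iterable):
        arr.append(x)
        # With fn=None and n=0 this test never fires (len(arr) >= 1),
        # so every request accumulates into a single chunk.
        if len(arr) == (fn(i, iterable) if fn else n):
            yield arr
            arr = []
    if arr:
        yield arr

requests = list(range(10))
print([len(c) for c in chunks(requests, n=4)])  # [4, 4, 2] -- fixed batching
print([len(c) for c in chunks(requests, n=0)])  # [10]      -- "auto": one chunk

# Hypothetical usage (argument names taken from the diff; model path illustrative):
# lm = VLLM(pretrained="facebook/opt-125m", batch_size="auto")

In short, an integer batch_size keeps the old fixed-size chunking, while "auto" defers batching entirely to vLLM.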