"...resnet50_tensorflow.git" did not exist on "2a595d51dd5c1f7a84dc47939fbb61b0f6e991b2"
Commit 4f79fa7b authored by baberabb
Browse files

fix args

parent 953bfdd2
...@@ -9,6 +9,8 @@ from lm_eval.api.registry import register_model ...@@ -9,6 +9,8 @@ from lm_eval.api.registry import register_model
from lm_eval import utils from lm_eval import utils
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
eval_logger = utils.eval_logger
@register_model("vllm") @register_model("vllm")
class VLLM(LM): class VLLM(LM):
...@@ -22,24 +24,34 @@ class VLLM(LM): ...@@ -22,24 +24,34 @@ class VLLM(LM):
trust_remote_code: Optional[bool] = False, trust_remote_code: Optional[bool] = False,
tokenizer_mode: Literal["auto", "slow"] = "auto", tokenizer_mode: Literal["auto", "slow"] = "auto",
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
quantization: Optional[str] = None, quantization: Optional[Literal["awq"]] = None,
max_gen_toks: int = 256, max_gen_toks: int = 256,
swap_space: int = 4, swap_space: int = 4,
batch_size: int = 1, batch_size: int = 1,
max_batch_size=None,
max_length: int = None, max_length: int = None,
seed: int = 1234,
gpu_memory_utilization: int = 0.9,
device: str = "cuda",
): ):
super().__init__() super().__init__()
assert "cuda" in device or device is None, "vLLM only supports CUDA"
if batch_size == "auto":
eval_logger.info(
"vllm does not support auto selection. Setting batch_size to 8"
)
batch_size = 8
self.model = LLM( self.model = LLM(
model=pretrained, model=pretrained,
gpu_memory_utilization=0.9, gpu_memory_utilization=gpu_memory_utilization,
revision=revision, revision=revision,
dtype=dtype, dtype=dtype,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tensor_parallel_size=int(tensor_parallel_size), tensor_parallel_size=int(tensor_parallel_size),
swap_space=swap_space, swap_space=int(swap_space),
quantization=quantization, quantization=quantization,
seed=seed,
) )
self.tokenizer = self.model.get_tokenizer() self.tokenizer = self.model.get_tokenizer()
self.batch_size = batch_size self.batch_size = batch_size
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment