Unverified Commit 2a47159c authored by NanoCode012, committed by GitHub

fix: passing max_length to vllm engine args (#1124)

* fix: passing max_length to vllm engine args

* feat: add `max_model_len`

* chore: lint
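
In short, the patch adds a vLLM-native `max_model_len` argument, asserts that it is not combined with lm-eval's existing `max_length`, and forwards whichever is set into the engine args. A minimal sketch of that resolution logic, mirroring the diff below (the helper name `resolve_max_model_len` is illustrative, not part of the patch):

from typing import Optional


def resolve_max_model_len(
    max_length: Optional[int] = None,
    max_model_len: Optional[int] = None,
) -> Optional[int]:
    """Sketch of the commit's logic; the helper name is hypothetical."""
    # Both arguments cap the same thing (the engine's context window),
    # so accepting both at once would be ambiguous; fail fast instead.
    assert (
        max_length is None or max_model_len is None
    ), "Either max_length or max_model_len may be provided, but not both"
    # Prefer the vLLM-native name; fall back to lm-eval's max_length.
    return max_model_len if max_model_len is not None else max_length


# The resolved value feeds vLLM's engine args exactly as in the diff:
resolved = resolve_max_model_len(max_length=2048)
engine_args = {"max_model_len": int(resolved) if resolved else None}
print(engine_args)  # {'max_model_len': 2048}

Collapsing the two names to a single value keeps `max_model_len` authoritative when it is given, since that is the keyword vLLM's engine actually consumes.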
parent c4f8c40e
@@ -46,6 +46,7 @@ class VLLM(LM):
         batch_size: Union[str, int] = 1,
         max_batch_size=None,
         max_length: int = None,
+        max_model_len: int = None,
         seed: int = 1234,
         gpu_memory_utilization: float = 0.9,
         device: str = "cuda",
@@ -62,6 +63,11 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
             )
         assert "cuda" in device or device is None, "vLLM only supports CUDA"
+        assert (
+            max_length is None or max_model_len is None
+        ), "Either max_length or max_model_len may be provided, but not both"
+
+        self._max_length = max_model_len if max_model_len is not None else max_length
         self.tensor_parallel_size = int(tensor_parallel_size)
         self.data_parallel_size = int(data_parallel_size)
         self.model_args = {
@@ -74,6 +80,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
             "tokenizer_revision": tokenizer_revision,
             "trust_remote_code": trust_remote_code,
             "tensor_parallel_size": int(tensor_parallel_size),
+            "max_model_len": int(self._max_length) if self._max_length else None,
             "swap_space": int(swap_space),
             "quantization": quantization,
             "seed": int(seed),
@@ -89,7 +96,6 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
             tokenizer_revision=tokenizer_revision,
         )
         self.batch_size = batch_size
-        self._max_length = max_length
         self._max_gen_toks = max_gen_toks

     @property
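
For reference, after this change the cap can be passed straight through the harness's model constructor. A hypothetical usage sketch, assuming the 0.4.x module layout (`lm_eval.models.vllm_causallms`), a constructor that accepts `pretrained`, and a placeholder model name:

from lm_eval.models.vllm_causallms import VLLM

# Setting the vLLM-native name; it lands in model_args["max_model_len"].
lm = VLLM(pretrained="EleutherAI/pythia-1.4b", max_model_len=2048)

# Supplying both caps now fails fast with the new assertion, before the
# engine is ever built:
try:
    VLLM(pretrained="EleutherAI/pythia-1.4b", max_length=2048, max_model_len=4096)
except AssertionError as err:
    print(err)  # Either max_length or max_model_len may be provided, but not both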