Unverified Commit 08074cf9 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Update vllm_causallms.py

parent 581dd9ff
...@@ -7,11 +7,13 @@ import copy
from tqdm import tqdm
from lm_eval.api.registry import register_model
from lm_eval import utils
try:
    from vllm import LLM, SamplingParams
except ModuleNotFoundError:
    pass
eval_logger = utils.eval_logger
...@@ -43,10 +45,10 @@ class VLLM(LM):
            import vllm
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. \
please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`",
            )
        assert "cuda" in device or device is None, "vLLM only supports CUDA"
        self.model = LLM(
            model=pretrained,
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload (0%)
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment