Unverified Commit 83fd78a2 authored by bcicc, committed by GitHub

vllm lora support (#1756)

* vllm lora support

* remove print

* version check, rename lora kwarg
parent caaf9ab6
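
With this change, the vLLM backend accepts a `lora_local_path` model argument pointing at a locally saved LoRA adapter. A minimal usage sketch, assuming vLLM > 0.3.0; the model name, adapter path, and task are placeholders, and `enable_lora=True` is an assumption about what the underlying `vllm.LLM` engine needs (it would be forwarded via `**kwargs`, not set by this commit):

```python
# Hypothetical invocation of the harness with the new kwarg from this commit.
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=(
        "pretrained=meta-llama/Llama-2-7b-hf,"
        "enable_lora=True,"                      # assumption: required by vllm.LLM
        "lora_local_path=/path/to/lora_adapter"  # placeholder path
    ),
    tasks=["hellaswag"],
)
```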
```diff
@@ -21,10 +21,14 @@ from lm_eval.utils import (
 try:
     import ray
     from vllm import LLM, SamplingParams
+
+    if parse_version(version("vllm")) > parse_version("0.3.0"):
+        from vllm.lora.request import LoRARequest
+
     from vllm.transformers_utils.tokenizer import get_tokenizer
 except ModuleNotFoundError:
     pass

 eval_logger = eval_logger
```
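
For context: `version` and `parse_version` are not imported in this hunk, so the sketch below assumes the conventional helpers, `importlib.metadata.version` and `packaging.version.parse`:

```python
# Sketch of the version-gated import pattern used above (helper imports assumed).
from importlib.metadata import version
from packaging.version import parse as parse_version

# LoRARequest only exists in vLLM releases newer than 0.3.0, so guard the import.
if parse_version(version("vllm")) > parse_version("0.3.0"):
    from vllm.lora.request import LoRARequest
```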
```diff
@@ -55,6 +59,7 @@ class VLLM(TemplateLM):
         gpu_memory_utilization: float = 0.9,
         device: str = "cuda",
         data_parallel_size: int = 1,
+        lora_local_path: str = None,
         **kwargs,
     ):
         super().__init__()
```
```diff
@@ -127,6 +132,14 @@ class VLLM(TemplateLM):
         self._max_gen_toks = max_gen_toks

+        if lora_local_path is not None:
+            assert parse_version(version("vllm")) > parse_version(
+                "0.3.0"
+            ), "lora adapters only compatible with vllm > v0.3.0."
+            self.lora_request = LoRARequest("finetuned", 1, lora_local_path)
+        else:
+            self.lora_request = None
+
     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
```
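
The positional call above matches `LoRARequest(lora_name, lora_int_id, lora_local_path)`: a human-readable adapter name, a unique integer id, and the adapter's local path. A standalone sketch of the same object in plain vLLM; the model name and adapter path are placeholders, and `enable_lora=True` reflects vLLM's requirement that LoRA be enabled at engine construction:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder model and adapter path; enable_lora=True must be set when the
# engine is constructed for LoRA requests to be honored.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
lora = LoRARequest("finetuned", 1, "/path/to/lora_adapter")

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=16),
    lora_request=lora,
)
```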
```diff
@@ -223,11 +236,19 @@ class VLLM(TemplateLM):
             # flatten results
             return undistribute(results)

-        outputs = self.model.generate(
-            prompt_token_ids=requests,
-            sampling_params=sampling_params,
-            use_tqdm=True if self.batch_size == "auto" else False,
-        )
+        if self.lora_request is not None:
+            outputs = self.model.generate(
+                prompt_token_ids=requests,
+                sampling_params=sampling_params,
+                use_tqdm=True if self.batch_size == "auto" else False,
+                lora_request=self.lora_request,
+            )
+        else:
+            outputs = self.model.generate(
+                prompt_token_ids=requests,
+                sampling_params=sampling_params,
+                use_tqdm=True if self.batch_size == "auto" else False,
+            )
         return outputs

     def loglikelihood_rolling(
```
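
A note on the duplicated call: per the version gate above, vLLM releases at or below 0.3.0 are treated as lacking LoRA support, so unconditionally passing a `lora_request` keyword could fail on those older releases; branching keeps the no-LoRA path on the old call signature. A more compact but behaviorally equivalent formulation (a sketch, not the committed code):

```python
# Sketch: build the optional kwarg once instead of duplicating generate().
lora_kwargs = (
    {"lora_request": self.lora_request} if self.lora_request is not None else {}
)
outputs = self.model.generate(
    prompt_token_ids=requests,
    sampling_params=sampling_params,
    use_tqdm=self.batch_size == "auto",
    **lora_kwargs,
)
```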