gaoqiong / lm-evaluation-harness · Commits

Commit 83fd78a2 (unverified)
Authored May 02, 2024 by bcicc; committed by GitHub on May 02, 2024
Parent: caaf9ab6

vllm lora support (#1756)

* vllm lora support
* remove print
* version check, rename lora kwarg
Showing 1 changed file with 26 additions and 5 deletions:

lm_eval/models/vllm_causallms.py (+26, -5)
lm_eval/models/vllm_causallms.py

@@ -21,10 +21,14 @@ from lm_eval.utils import (
 try:
     import ray
     from vllm import LLM, SamplingParams
+
+    if parse_version(version("vllm")) > parse_version("0.3.0"):
+        from vllm.lora.request import LoRARequest
     from vllm.transformers_utils.tokenizer import get_tokenizer
 except ModuleNotFoundError:
     pass

 eval_logger = eval_logger
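The guard above imports LoRARequest only when the installed vllm is newer than 0.3.0, matching the assert message later in this commit ("lora adapters only compatible with vllm > v0.3.0"). A minimal sketch of the same version-gating pattern, assuming `version` comes from importlib.metadata and `parse_version` from packaging (their imports sit above this hunk and are not shown in the diff); the `LoRARequest = None` fallback is a hypothetical addition, not part of the commit:

    from importlib.metadata import version  # reads the installed package's metadata
    from packaging.version import parse as parse_version

    # Import LoRA support only on vllm releases that ship it (> 0.3.0).
    if parse_version(version("vllm")) > parse_version("0.3.0"):
        from vllm.lora.request import LoRARequest
    else:
        LoRARequest = None  # hypothetical fallback so `is not None` checks still work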
@@ -55,6 +59,7 @@ class VLLM(TemplateLM):
         gpu_memory_utilization: float = 0.9,
         device: str = "cuda",
         data_parallel_size: int = 1,
+        lora_local_path: str = None,
         **kwargs,
     ):
         super().__init__()
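The new `lora_local_path` keyword is the only constructor change: it points at a LoRA adapter on disk and defaults to None, leaving non-LoRA behavior untouched. A hypothetical direct construction of the wrapper, with placeholder model and adapter paths, assuming extra kwargs such as `enable_lora` are forwarded to vllm.LLM:

    from lm_eval.models.vllm_causallms import VLLM

    lm = VLLM(
        pretrained="meta-llama/Llama-2-7b-hf",  # placeholder base model
        lora_local_path="/path/to/adapter",     # new kwarg from this commit
        enable_lora=True,                       # assumption: reaches vllm.LLM via **kwargs
    )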
@@ -127,6 +132,14 @@ class VLLM(TemplateLM):
         self._max_gen_toks = max_gen_toks

+        if lora_local_path is not None:
+            assert parse_version(version("vllm")) > parse_version(
+                "0.3.0"
+            ), "lora adapters only compatible with vllm > v0.3.0."
+            self.lora_request = LoRARequest("finetuned", 1, lora_local_path)
+        else:
+            self.lora_request = None
+
     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
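Here the adapter path is wrapped in a vLLM LoRARequest, whose three positional arguments are a human-readable adapter name, a unique integer id, and the local path to the adapter weights. A self-contained sketch of the same flow outside the harness, with placeholder model and adapter paths (note vLLM also requires `enable_lora=True` on the LLM itself, which the harness would receive through **kwargs):

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # The base model must be constructed with LoRA serving enabled.
    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)  # placeholders

    # Adapter name, unique integer id, and on-disk adapter path.
    lora_request = LoRARequest("finetuned", 1, "/path/to/adapter")

    outputs = llm.generate(
        ["The capital of France is"],
        SamplingParams(max_tokens=16),
        lora_request=lora_request,  # route generation through the adapter
    )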
@@ -223,11 +236,19 @@ class VLLM(TemplateLM):
             # flatten results
             return undistribute(results)

-        outputs = self.model.generate(
-            prompt_token_ids=requests,
-            sampling_params=sampling_params,
-            use_tqdm=True if self.batch_size == "auto" else False,
-        )
+        if self.lora_request is not None:
+            outputs = self.model.generate(
+                prompt_token_ids=requests,
+                sampling_params=sampling_params,
+                use_tqdm=True if self.batch_size == "auto" else False,
+                lora_request=self.lora_request,
+            )
+        else:
+            outputs = self.model.generate(
+                prompt_token_ids=requests,
+                sampling_params=sampling_params,
+                use_tqdm=True if self.batch_size == "auto" else False,
+            )
         return outputs

     def loglikelihood_rolling(
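Taken together, the three hunks let the harness evaluate a LoRA adapter on top of a vLLM-served base model: the adapter path arrives through `lora_local_path`, is wrapped in a LoRARequest at construction time, and is attached to every generate call. A hypothetical end-to-end invocation through the harness's Python API, with placeholder model, adapter path, and task, assuming `model_args` is passed through to the VLLM constructor as usual:

    import lm_eval

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=(
            "pretrained=meta-llama/Llama-2-7b-hf,"  # placeholder base model
            "enable_lora=True,"                     # forwarded to vllm.LLM
            "lora_local_path=/path/to/adapter"      # the new kwarg from this commit
        ),
        tasks=["gsm8k"],  # placeholder task
    )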