gaoqiong / lm-evaluation-harness · Commits

Commit 7b7dd042: modularize vllm
Authored Feb 04, 2025 by Baber
Parent: 8fada609
Showing 1 changed file with 29 additions and 10 deletions.

lm_eval/models/vllm_vlms.py (+29, -10)
```diff
 import copy
-from typing import Dict, List, Optional
+import json
+from typing import Dict, List, NamedTuple, Optional

 import transformers
 from more_itertools import distribute
```
```diff
@@ -29,6 +30,13 @@ except ModuleNotFoundError:
 DEFAULT_IMAGE_PLACEHOLDER = "<image>"


+class JsonChatStr(NamedTuple):
+    prompt: str
+
+    def encode(self, encoding):
+        return self.prompt.encode(encoding)
+
+
 @register_model("vllm-vlm")
 class VLLM_VLM(VLLM):
     MULTIMODAL = True
```
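The new `JsonChatStr` wrapper keeps a JSON-serialized chat history while still quacking like a prompt string. A minimal sketch of how it behaves (the chat payload below is illustrative, not from the commit):

```python
import json
from typing import NamedTuple


class JsonChatStr(NamedTuple):
    prompt: str

    def encode(self, encoding):
        # Defer to the underlying JSON string, like str.encode would.
        return self.prompt.encode(encoding)


# Hypothetical chat history; any OpenAI-style message list works here.
chat = [{"role": "user", "content": "Describe the image."}]
wrapped = JsonChatStr(json.dumps(chat))
print(wrapped.encode("utf-8"))
# b'[{"role": "user", "content": "Describe the image."}]'
```

Because downstream code only calls `.encode(...)` on prompts, the wrapper can flow through paths that expect plain strings, and the JSON can be decoded back into messages where a chat API needs them.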
```diff
@@ -43,6 +51,7 @@ class VLLM_VLM(VLLM):
         max_images: int = 999,
         **kwargs,
     ):
+        self.pretrained = pretrained
         if max_images != 999:
             kwargs["limit_mm_per_prompt"] = {"image": max_images}
             eval_logger.info(f"Setting limit_mm_per_prompt[image] to {max_images}")
```
```diff
@@ -90,6 +99,12 @@ class VLLM_VLM(VLLM):
             outputs.append(inputs)
         return outputs

+    def _generate(self, model, *args, **kwargs):
+        if "pixtral" not in self.pretrained:
+            return model.generate(*args, **kwargs)
+        else:
+            model.chat(**kwargs)
+
     def _model_generate(
         self,
         requests: List[List[dict]] = None,
```
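The new `_generate` helper centralizes the choice between vLLM's two entry points, keyed off the model name that the hunk above stores as `self.pretrained`. A minimal sketch of the dispatch with stand-in objects (all names below are illustrative, not from the commit):

```python
class FakeEngine:
    """Stand-in for vllm.LLM; only mimics the two entry points."""

    def generate(self, *args, **kwargs):
        return "via LLM.generate"

    def chat(self, **kwargs):
        return "via LLM.chat"


def dispatch(pretrained: str, model, *args, **kwargs):
    # Mirrors _generate: pixtral checkpoints go through the chat API,
    # everything else through plain generate.
    if "pixtral" not in pretrained:
        return model.generate(*args, **kwargs)
    return model.chat(**kwargs)


print(dispatch("llava-hf/llava-1.5-7b-hf", FakeEngine()))   # via LLM.generate
print(dispatch("mistralai/pixtral-12b-2409", FakeEngine())) # via LLM.chat
# Note: the substring check is case-sensitive as committed, so a name
# containing "Pixtral" (capital P) would fall through to generate.
```

Also worth noting: in the committed version the `chat` branch does not return its result, so callers on that path would receive `None` unless a later change adds the `return`.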
```diff
@@ -116,7 +131,7 @@ class VLLM_VLM(VLLM):
                 model_args: dict, sampling_params, requests: List[List[dict]]
             ):
                 llm = LLM(**model_args)
-                return llm.generate(requests, sampling_params=sampling_params)
+                return self._generate(llm, requests, sampling_params=sampling_params)

             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
```
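The surrounding comments describe the interleaved dispatch: `more_itertools.distribute` (imported at the top of the file) deals requests round-robin across the ray workers, and lm-eval's `undistribute` helper inverts the interleaving afterwards. A quick sketch of what the interleaving looks like:

```python
from more_itertools import distribute

requests = ["r0", "r1", "r2", "r3", "r4", "r5", "r6"]
# Deal requests to 3 workers like cards: worker i gets items i, i+3, i+6, ...
workers = [list(chunk) for chunk in distribute(3, requests)]
print(workers)  # [['r0', 'r3', 'r6'], ['r1', 'r4'], ['r2', 'r5']]
```

Because consecutive requests are often similar in length, dealing them out round-robin keeps each worker's mix of context lengths roughly balanced, which is exactly what the in-file comment calls out.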
```diff
@@ -130,14 +145,16 @@ class VLLM_VLM(VLLM):
             return undistribute(results)

         if self.lora_request is not None:
-            outputs = self.model.generate(
+            outputs = self._generate(
+                self.model,
                 requests,
                 sampling_params=sampling_params,
                 use_tqdm=True if self.batch_size == "auto" else False,
                 lora_request=self.lora_request,
             )
         else:
-            outputs = self.model.generate(
+            outputs = self._generate(
+                self.model,
                 requests,
                 sampling_params=sampling_params,
                 use_tqdm=True if self.batch_size == "auto" else False,
```
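Both call sites change only in their first line: the engine handle is now passed explicitly, and every other option rides through `_generate`'s `**kwargs` untouched. A tiny sketch of that forwarding (the echoing lambda is a stand-in for `model.generate`, not the commit's code):

```python
def _generate_stub(model_generate, *args, **kwargs):
    # Like _generate's non-pixtral branch: pass everything through untouched.
    return model_generate(*args, **kwargs)


received = _generate_stub(
    lambda *a, **kw: kw,  # pretend model.generate just echoes its kwargs
    ["request"],
    sampling_params="sp",
    use_tqdm=False,
    lora_request="lora-0",
)
print(received)
# {'sampling_params': 'sp', 'use_tqdm': False, 'lora_request': 'lora-0'}
```

This is what lets the LoRA and non-LoRA branches share one helper: only the extra `lora_request` keyword differs between them.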
```diff
@@ -194,12 +211,14 @@ class VLLM_VLM(VLLM):
             raise ValueError(
                 f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
             )

-        return self.processor.apply_chat_template(
-            chat_history,
-            add_generation_prompt=add_generation_prompt,
-            continue_final_message=not add_generation_prompt,
-        )
+        if hasattr(self.processor, "apply_chat_template"):
+            return self.processor.apply_chat_template(
+                chat_history,
+                add_generation_prompt=add_generation_prompt,
+                continue_final_message=not add_generation_prompt,
+            )
+        else:
+            return JsonChatStr(json.dumps(chat_history))

     def generate_until(
         self, requests: List[Instance], disable_tqdm: bool = False
```
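Finally, the chat-template path gains a fallback for processors that cannot render a template: the raw history is preserved as `JsonChatStr` so the chat API can consume it later. A self-contained sketch, using a hypothetical templateless processor rather than a real HF class:

```python
import json
from typing import NamedTuple


class JsonChatStr(NamedTuple):  # as introduced earlier in this commit
    prompt: str

    def encode(self, encoding):
        return self.prompt.encode(encoding)


class NoTemplateProcessor:
    """Hypothetical processor lacking apply_chat_template."""


chat_history = [{"role": "user", "content": "<image>\nWhat is shown?"}]
processor = NoTemplateProcessor()

if hasattr(processor, "apply_chat_template"):
    prompt = processor.apply_chat_template(chat_history, add_generation_prompt=True)
else:
    prompt = JsonChatStr(json.dumps(chat_history))

print(type(prompt).__name__)  # JsonChatStr
```

The `continue_final_message=not add_generation_prompt` pairing in the template branch means the rendered prompt either opens a fresh assistant turn or continues the final message, never both at once.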