Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e3dd0692
Unverified
Commit
e3dd0692
authored
Sep 24, 2024
by
zifeitong
Committed by
GitHub
Sep 25, 2024
Browse files
[BugFix] Propagate 'trust_remote_code' setting in internvl and minicpmv (#8250)
parent
fc3afc20
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
126 additions
and
41 deletions
+126
-41
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+9
-6
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+108
-29
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+9
-6
No files found.
vllm/model_executor/models/internvl.py
View file @
e3dd0692
...
...
@@ -230,8 +230,9 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
prompt
=
llm_inputs
.
get
(
"prompt"
)
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
...
...
@@ -278,8 +279,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
use_thumbnail
=
use_thumbnail
)
for
img
in
data
]
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
image_token_id
=
tokenizer
.
encode
(
IMG_CONTEXT
,
add_special_tokens
=
False
,
return_tensors
=
"pt"
)[
0
]
...
...
@@ -298,8 +300,9 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
seq_data
=
dummy_seq_data_for_clip
(
vision_config
,
...
...
vllm/model_executor/models/minicpmv.py
View file @
e3dd0692
...
...
@@ -33,6 +33,7 @@ from PIL import Image
from
torch
import
nn
from
torch.nn.init
import
trunc_normal_
from
transformers
import
PretrainedConfig
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
...
...
@@ -52,6 +53,7 @@ from vllm.model_executor.models.minicpm import MiniCPMModel
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
...
...
@@ -64,6 +66,17 @@ _KEYS_TO_MODIFY_MAPPING = {
}
class
MiniCPMVImageInput
(
TypedDict
):
"""Input mapper input with auxiliary data for computing image bounds."""
image
:
Image
.
Image
# Image bounds token ids in 0-dim scaler tensor.
im_start_id
:
torch
.
Tensor
im_end_id
:
torch
.
Tensor
slice_start_id
:
NotRequired
[
torch
.
Tensor
]
slice_end_id
:
NotRequired
[
torch
.
Tensor
]
class
MiniCPMVImagePixelInputs
(
TypedDict
):
pixel_values
:
List
[
torch
.
Tensor
]
"""
...
...
@@ -88,8 +101,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
"""
MiniCPMVImageInputs
=
MiniCPMVImagePixelInputs
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
...
...
@@ -234,6 +245,25 @@ class Resampler2_5(BaseResampler):
return
x
def
_build_image_input
(
ctx
:
InputContext
,
image
:
Image
.
Image
)
->
MiniCPMVImageInput
:
tokenizer
=
cached_get_tokenizer
(
ctx
.
model_config
.
tokenizer
,
trust_remote_code
=
ctx
.
model_config
.
trust_remote_code
)
if
hasattr
(
tokenizer
,
"slice_start_id"
):
return
MiniCPMVImageInput
(
image
=
image
,
im_start_id
=
torch
.
tensor
(
tokenizer
.
im_start_id
),
im_end_id
=
torch
.
tensor
(
tokenizer
.
im_end_id
),
slice_start_id
=
torch
.
tensor
(
tokenizer
.
slice_start_id
),
slice_end_id
=
torch
.
tensor
(
tokenizer
.
slice_end_id
))
else
:
return
MiniCPMVImageInput
(
image
=
image
,
im_start_id
=
torch
.
tensor
(
tokenizer
.
im_start_id
),
im_end_id
=
torch
.
tensor
(
tokenizer
.
im_end_id
))
def
get_version_by_config
(
config
:
PretrainedConfig
)
->
Tuple
[
int
,
...]:
version_float
=
getattr
(
config
,
"version"
,
None
)
...
...
@@ -257,10 +287,13 @@ def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
return
SequenceData
.
from_token_counts
((
0
,
seq_len
))
def
dummy_image_for_minicpmv
(
hf_config
:
PretrainedConfig
,
num_images
:
int
):
def
dummy_image_for_minicpmv
(
ctx
:
InputContext
,
hf_config
:
PretrainedConfig
,
num_images
:
int
):
width
=
height
=
hf_config
.
image_size
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
image
=
_build_image_input
(
ctx
,
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
))
return
{
"image"
:
[
image
]
if
num_images
==
1
else
[
image
]
*
num_images
}
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
,
...
...
@@ -269,7 +302,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
num_images
=
mm_counts
[
"image"
]
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
,
num_images
)
mm_data
=
dummy_image_for_minicpmv
(
hf_config
,
num_images
)
mm_data
=
dummy_image_for_minicpmv
(
ctx
,
hf_config
,
num_images
)
return
seq_data
,
mm_data
...
...
@@ -280,8 +313,9 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
model_config
=
ctx
.
model_config
version
=
get_version_by_config
(
model_config
.
hf_config
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
image_processor
=
cached_get_image_processor
(
model_config
.
tokenizer
)
def
get_placeholder
(
image_size
:
Tuple
[
int
,
int
],
num_image
:
int
):
...
...
@@ -317,6 +351,10 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
new_prompt
=
""
.
join
(
new_prompt_chunks
)
new_token_ids
=
tokenizer
.
encode
(
new_prompt
)
multi_modal_data
[
"image"
]
=
[
_build_image_input
(
ctx
,
image
)
for
image
in
images
]
llm_inputs
=
LLMInputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
...
...
@@ -325,6 +363,32 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
def
input_mapper_for_minicpmv
(
ctx
:
InputContext
,
data
:
object
):
model_config
=
ctx
.
model_config
image_processor
=
cached_get_image_processor
(
model_config
.
model
,
trust_remote_code
=
model_config
.
trust_remote_code
)
if
image_processor
is
None
:
raise
RuntimeError
(
"No HuggingFace processor is available "
"to process the image object"
)
if
not
isinstance
(
data
,
list
):
raise
ValueError
(
"Image input must be list of MiniCPMVImageInput, got (%s)"
,
data
)
batch_data
=
image_processor
\
.
preprocess
([
img
[
"image"
]
for
img
in
data
],
return_tensors
=
"pt"
)
\
.
data
if
len
(
data
)
>
0
:
batch_data
[
"im_start_id"
]
=
data
[
0
][
"im_start_id"
]
batch_data
[
"im_end_id"
]
=
data
[
0
][
"im_end_id"
]
if
"slice_start_id"
in
data
[
0
]:
batch_data
[
"slice_start_id"
]
=
data
[
0
][
"slice_start_id"
]
batch_data
[
"slice_end_id"
]
=
data
[
0
][
"slice_end_id"
]
return
MultiModalInputs
(
batch_data
)
class
MiniCPMVBaseModel
(
nn
.
Module
,
SupportsMultiModal
):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
...
...
@@ -365,7 +429,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
def
get_embedding
(
self
,
input_ids
:
torch
.
Tensor
,
image_inputs
:
Optional
[
MiniCPMVImageInputs
],
image_inputs
:
Optional
[
MiniCPMVImage
Pixel
Inputs
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
vlm_embedding
:
torch
.
Tensor
=
self
.
llm
.
embed_tokens
(
input_ids
)
if
hasattr
(
self
.
config
,
"scale_emb"
):
...
...
@@ -393,14 +457,20 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
return
vlm_embedding
,
vision_hidden_states
def
_get_image_bounds
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tokenizer
=
cached_get_tokenizer
(
self
.
config
.
_name_or_path
,
trust_remote_code
=
True
)
start_cond
=
input_ids
==
tokenizer
.
im_start_id
end_cond
=
input_ids
==
tokenizer
.
im_end_id
if
hasattr
(
tokenizer
,
"slice_start_id"
):
start_cond
|=
(
input_ids
==
tokenizer
.
slice_start_id
)
end_cond
|=
(
input_ids
==
tokenizer
.
slice_end_id
)
def
_get_image_bounds
(
self
,
input_ids
:
torch
.
Tensor
,
im_start_id
:
torch
.
Tensor
,
im_end_id
:
torch
.
Tensor
,
slice_start_id
:
Optional
[
torch
.
Tensor
]
=
None
,
slice_end_id
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# All the images in the batch should share the same special image
# bound token ids.
start_cond
=
input_ids
==
im_start_id
[
0
]
end_cond
=
input_ids
==
im_end_id
[
0
]
if
slice_start_id
is
not
None
:
start_cond
|=
(
input_ids
==
slice_start_id
[
0
])
end_cond
|=
(
input_ids
==
slice_end_id
[
0
])
image_start_tokens
,
=
torch
.
where
(
start_cond
)
image_start_tokens
+=
1
...
...
@@ -419,7 +489,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
self
,
input_ids
:
torch
.
Tensor
,
**
kwargs
:
object
,
)
->
Optional
[
MiniCPMVImageInputs
]:
)
->
Optional
[
MiniCPMVImage
Pixel
Inputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
[])
tgt_sizes
=
kwargs
.
pop
(
"tgt_sizes"
,
[])
...
...
@@ -456,8 +526,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
if
len
(
pixel_values_flat
)
==
0
:
return
None
return
MiniCPMVImageInputs
(
image_bounds
=
self
.
_get_image_bounds
(
input_ids
),
im_start_id
=
kwargs
.
pop
(
"im_start_id"
,
None
)
im_end_id
=
kwargs
.
pop
(
"im_end_id"
,
None
)
slice_start_id
=
kwargs
.
pop
(
"slice_start_id"
,
None
)
slice_end_id
=
kwargs
.
pop
(
"slice_end_id"
,
None
)
if
im_start_id
is
None
:
return
None
return
MiniCPMVImagePixelInputs
(
image_bounds
=
self
.
_get_image_bounds
(
input_ids
,
im_start_id
,
im_end_id
,
slice_start_id
,
slice_end_id
),
pixel_values
=
pixel_values_flat
,
tgt_sizes
=
torch
.
stack
(
tgt_sizes_flat
),
)
...
...
@@ -564,8 +643,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
...
...
@@ -654,8 +733,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
res
.
append
(
self
.
resampler
(
vision_embedding
,
tgt_size
))
return
torch
.
vstack
(
res
)
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"pixel_values"
]
return
self
.
get_vision_embedding
(
pixel_values
)
...
...
@@ -713,8 +792,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
vision_embedding
=
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
return
vision_embedding
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"pixel_values"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
...
...
@@ -807,8 +886,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
).
last_hidden_state
return
vision_embedding
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"pixel_values"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
...
...
@@ -851,7 +930,7 @@ _SUPPORT_VERSION = {
}
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_minicpmv
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_minicpmv_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_minicpmv
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_minicpmv
)
...
...
vllm/model_executor/models/qwen.py
View file @
e3dd0692
...
...
@@ -674,8 +674,9 @@ def input_processor_for_qwen(ctx: InputContext,
prompt
=
llm_inputs
.
get
(
"prompt"
)
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
torch
.
Tensor
):
num_dims
=
len
(
image_data
.
shape
)
...
...
@@ -735,8 +736,9 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
return
MultiModalInputs
()
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
image_pair_tok
=
tokenizer
.
encode
(
IMG_START
+
IMG_END
,
add_special_tokens
=
False
,
...
...
@@ -824,8 +826,9 @@ def dummy_data_for_qwen(
# We have a visual component - use images to warm up
num_images
=
mm_counts
[
"image"
]
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
# Build the image prompts with no imgpads; the tokenizer will add img pads
image_prompt
=
''
.
join
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment