Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2ae25f79
Unverified
Commit
2ae25f79
authored
Sep 30, 2024
by
Isotr0py
Committed by
GitHub
Sep 30, 2024
Browse files
[Model] Expose InternVL2 max_dynamic_patch as a mm_processor_kwarg (#8946)
parent
8e60afa1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
90 additions
and
61 deletions
+90
-61
examples/offline_inference_vision_language_multi_image.py
examples/offline_inference_vision_language_multi_image.py
+1
-0
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+89
-61
No files found.
examples/offline_inference_vision_language_multi_image.py
View file @
2ae25f79
...
@@ -115,6 +115,7 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
...
@@ -115,6 +115,7 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
trust_remote_code
=
True
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
max_model_len
=
4096
,
limit_mm_per_prompt
=
{
"image"
:
len
(
image_urls
)},
limit_mm_per_prompt
=
{
"image"
:
len
(
image_urls
)},
mm_processor_kwargs
=
{
"max_dynamic_patch"
:
4
},
)
)
placeholders
=
"
\n
"
.
join
(
f
"Image-
{
i
}
: <image>
\n
"
placeholders
=
"
\n
"
.
join
(
f
"Image-
{
i
}
: <image>
\n
"
...
...
vllm/model_executor/models/internvl.py
View file @
2ae25f79
...
@@ -5,8 +5,9 @@
...
@@ -5,8 +5,9 @@
# Licensed under The MIT License [see LICENSE for details]
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
# --------------------------------------------------------
import
re
import
re
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
from
functools
import
partial
TypedDict
,
Union
)
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -122,6 +123,20 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
...
@@ -122,6 +123,20 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
return
blocks
,
target_width
,
target_height
return
blocks
,
target_width
,
target_height
def
calculate_num_blocks_wrapper
(
hf_config
:
Dict
[
str
,
Any
],
max_dynamic_patch
:
Optional
[
int
]
=
None
):
if
max_dynamic_patch
is
None
:
max_dynamic_patch
=
hf_config
.
max_dynamic_patch
min_num
=
hf_config
.
min_dynamic_patch
image_size
=
hf_config
.
vision_config
.
image_size
use_thumbnail
=
hf_config
.
use_thumbnail
return
partial
(
calculate_num_blocks
,
min_num
=
min_num
,
max_num
=
max_dynamic_patch
,
image_size
=
image_size
,
use_thumbnail
=
use_thumbnail
)
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def
dynamic_preprocess
(
image
:
Image
.
Image
,
min_num
:
int
,
max_num
:
int
,
def
dynamic_preprocess
(
image
:
Image
.
Image
,
min_num
:
int
,
max_num
:
int
,
image_size
:
int
,
image_size
:
int
,
...
@@ -168,62 +183,85 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
...
@@ -168,62 +183,85 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
return
pixel_values
return
pixel_values
def
get_internvl_num_patches
(
image_size
:
int
,
patch_size
:
int
,
def
image_to_pixel_values_wrapper
(
hf_config
:
Dict
[
str
,
Any
],
downsample_ratio
:
float
):
max_dynamic_patch
:
Optional
[
int
]
=
None
):
image_size
=
hf_config
.
vision_config
.
image_size
min_num
=
hf_config
.
min_dynamic_patch
if
max_dynamic_patch
is
None
:
max_dynamic_patch
=
hf_config
.
max_dynamic_patch
use_thumbnail
=
hf_config
.
use_thumbnail
return
partial
(
image_to_pixel_values
,
input_size
=
image_size
,
min_num
=
min_num
,
max_num
=
max_dynamic_patch
,
use_thumbnail
=
use_thumbnail
)
def
get_internvl_num_patches
(
hf_config
:
Dict
[
str
,
Any
]):
vision_config
=
hf_config
.
vision_config
downsample_ratio
=
hf_config
.
downsample_ratio
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
return
int
(
return
int
(
get_clip_num_patches
(
image_size
=
image_size
,
patch_size
=
patch_size
)
*
get_clip_num_patches
(
image_size
=
image_size
,
patch_size
=
patch_size
)
*
(
downsample_ratio
**
2
))
(
downsample_ratio
**
2
))
def
get_max_internvl_image_tokens
(
ctx
:
InputContext
):
def
get_max_internvl_image_tokens
(
ctx
:
InputContext
,
*
,
max_dynamic_patch
:
Optional
[
int
]
=
None
):
hf_config
=
ctx
.
get_hf_config
()
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
if
max_dynamic_patch
is
None
:
max_dynamic_patch
=
hf_config
.
max_dynamic_patch
use_thumbnail
=
hf_config
.
use_thumbnail
use_thumbnail
=
hf_config
.
use_thumbnail
max_dynamic_patch
=
hf_config
.
max_dynamic_patch
if
use_thumbnail
and
max_dynamic_patch
>
1
:
if
use_thumbnail
:
max_dynamic_patch
+=
1
max_dynamic_patch
+=
1
downsample_ratio
=
hf_config
.
downsample_ratio
image_size
=
vision_config
.
image_size
num_patches
=
get_internvl_num_patches
(
hf_config
)
patch_size
=
vision_config
.
patch_size
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
return
num_patches
*
max_dynamic_patch
return
num_patches
*
max_dynamic_patch
def
input_processor_for_internvl
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
def
get_max_internvl_image_size
(
ctx
:
InputContext
,
*
,
max_dynamic_patch
:
Optional
[
int
]
=
None
):
hf_config
=
ctx
.
get_hf_config
()
image_size
=
hf_config
.
vision_config
.
image_size
if
max_dynamic_patch
is
None
:
max_dynamic_patch
=
hf_config
.
max_dynamic_patch
use_thumbnail
=
hf_config
.
use_thumbnail
if
use_thumbnail
and
max_dynamic_patch
>
1
:
max_dynamic_patch
+=
1
width
=
image_size
*
max_dynamic_patch
height
=
image_size
return
width
,
height
def
input_processor_for_internvl
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
,
*
,
max_dynamic_patch
:
Optional
[
int
]
=
None
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_inputs
return
llm_inputs
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
()
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
image_data
=
multi_modal_data
[
"image"
]
image_data
=
multi_modal_data
[
"image"
]
min_num
=
hf_config
.
min_dynamic_patch
num_patches
=
get_internvl_num_patches
(
hf_config
)
max_num
=
hf_config
.
max_dynamic_patch
num_blocks_calculator
=
calculate_num_blocks_wrapper
(
use_thumbnail
=
hf_config
.
use_thumbnail
hf_config
,
max_dynamic_patch
)
if
isinstance
(
image_data
,
Image
.
Image
):
if
isinstance
(
image_data
,
Image
.
Image
):
width
,
height
=
image_data
.
size
width
,
height
=
image_data
.
size
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
,
min_num
,
num_blocks
,
_
,
_
=
num_blocks_calculator
(
width
,
height
)
max_num
,
image_size
,
use_thumbnail
)
image_feature_size
=
[
num_blocks
*
num_patches
]
image_feature_size
=
[
num_blocks
*
num_patches
]
elif
is_list_of
(
image_data
,
Image
.
Image
):
elif
is_list_of
(
image_data
,
Image
.
Image
):
image_feature_size
=
[]
image_feature_size
=
[]
for
image
in
image_data
:
for
image
in
image_data
:
width
,
height
=
image
.
size
width
,
height
=
image
.
size
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
,
min_num
,
num_blocks
,
_
,
_
=
num_blocks_calculator
(
width
,
height
)
max_num
,
image_size
,
use_thumbnail
)
image_feature_size
.
append
(
num_blocks
*
num_patches
)
image_feature_size
.
append
(
num_blocks
*
num_patches
)
elif
isinstance
(
image_data
,
torch
.
Tensor
):
elif
isinstance
(
image_data
,
torch
.
Tensor
):
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
...
@@ -253,31 +291,21 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -253,31 +291,21 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data
=
multi_modal_data
)
multi_modal_data
=
multi_modal_data
)
def
input_mapper_for_internvl
(
ctx
:
InputContext
,
data
:
object
):
def
input_mapper_for_internvl
(
ctx
:
InputContext
,
data
:
object
,
*
,
max_dynamic_patch
:
Optional
[
int
]
=
None
):
hf_config
=
ctx
.
get_hf_config
()
hf_config
=
ctx
.
get_hf_config
()
use_thumbnail
=
hf_config
.
use_thumbnail
image_pixel_values_mapper
=
image_to_pixel_values_wrapper
(
min_num
=
hf_config
.
min_dynamic_patch
hf_config
,
max_dynamic_patch
)
max_num
=
hf_config
.
max_dynamic_patch
image_size
=
hf_config
.
vision_config
.
image_size
if
isinstance
(
data
,
Image
.
Image
):
if
isinstance
(
data
,
Image
.
Image
):
data
=
image_to_pixel_values
(
data
,
data
=
image_pixel_values_mapper
(
data
)
image_size
,
min_num
,
max_num
,
use_thumbnail
=
use_thumbnail
)
# Add an N dimension for number of images per prompt (currently 1).
# Add an N dimension for number of images per prompt (currently 1).
data
=
data
.
unsqueeze
(
0
)
data
=
data
.
unsqueeze
(
0
)
elif
is_list_of
(
data
,
Image
.
Image
):
elif
is_list_of
(
data
,
Image
.
Image
):
# we can't stack here because the images may have different num_patches
# we can't stack here because the images may have different num_patches
data
=
[
data
=
[
image_pixel_values_mapper
(
img
)
for
img
in
data
]
image_to_pixel_values
(
img
,
image_size
,
min_num
,
max_num
,
use_thumbnail
=
use_thumbnail
)
for
img
in
data
]
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
model_config
.
tokenizer
,
...
@@ -292,20 +320,24 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
...
@@ -292,20 +320,24 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
})
})
def
dummy_data_for_internvl
(
ctx
:
InputContext
,
seq_len
:
int
,
def
dummy_data_for_internvl
(
ctx
:
InputContext
,
mm_counts
:
Mapping
[
str
,
int
]):
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
*
,
max_dynamic_patch
:
Optional
[
int
]
=
None
):
num_images
=
mm_counts
[
"image"
]
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_internvl_image_tokens
(
ctx
)
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
()
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
image_feature_size
=
get_max_internvl_image_tokens
(
ctx
,
max_dynamic_patch
=
max_dynamic_patch
)
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
trust_remote_code
=
model_config
.
trust_remote_code
)
seq_data
=
dummy_seq_data_for_clip
(
seq_data
=
dummy_seq_data_for_clip
(
vision_config
,
hf_config
.
vision_config
,
seq_len
,
seq_len
,
num_images
,
num_images
,
image_token_id
=
tokenizer
.
encode
(
IMG_CONTEXT
,
image_token_id
=
tokenizer
.
encode
(
IMG_CONTEXT
,
...
@@ -313,14 +345,11 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
...
@@ -313,14 +345,11 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
image_feature_size_override
=
image_feature_size
,
image_feature_size_override
=
image_feature_size
,
)
)
image_size
=
vision_config
.
image_size
max_image_width
,
max_image_height
=
get_max_internvl_image_size
(
min_num
=
hf_config
.
min_dynamic_patch
ctx
,
max_dynamic_patch
=
max_dynamic_patch
)
max_num
=
hf_config
.
max_dynamic_patch
max_image_width
=
max_num
*
image_size
max_image_height
=
min_num
*
image_size
mm_data
=
dummy_image_for_clip
(
mm_data
=
dummy_image_for_clip
(
vision_config
,
hf_config
.
vision_config
,
num_images
,
num_images
,
image_width_override
=
max_image_width
,
image_width_override
=
max_image_width
,
image_height_override
=
max_image_height
,
image_height_override
=
max_image_height
,
...
@@ -470,7 +499,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -470,7 +499,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
self
,
self
,
image_input
:
InternVLImageInputs
,
image_input
:
InternVLImageInputs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
image_input
[
"type"
]
==
"image_embeds"
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
return
image_input
[
"data"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment