Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b7645476
Unverified
Commit
b7645476
authored
Aug 08, 2024
by
Isotr0py
Committed by
GitHub
Aug 07, 2024
Browse files
[Bugfix] Fix input processor for InternVL2 model (#7164)
Co-authored-by:
Cyrus Leung
<
cyrus.tl.leung@gmail.com
>
parent
ab0f5e28
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
34 deletions
+73
-34
tests/models/test_internvl.py
tests/models/test_internvl.py
+19
-4
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+54
-30
No files found.
tests/models/test_internvl.py
View file @
b7645476
...
...
@@ -5,6 +5,7 @@ import pytest
import
torch
from
huggingface_hub
import
snapshot_download
from
PIL.Image
import
Image
from
transformers
import
AutoConfig
from
vllm.model_executor.models.internvl
import
(
IMG_CONTEXT
,
IMG_END
,
IMG_START
,
...
...
@@ -26,10 +27,15 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
# we use snapshot_download to prevent conflicts between
# dynamic_module and trust_remote_code for hf_runner
DOWNLOAD_PATTERN
=
[
"*.json"
,
"*.py"
,
"*.safetensors"
,
"*.txt"
,
"*.model"
]
models
=
[
snapshot_download
(
"OpenGVLab/InternVL2-1B"
),
snapshot_download
(
"OpenGVLab/InternVL2-2B"
),
# snapshot_download("OpenGVLab/InternVL2-4B"), # broken
snapshot_download
(
"OpenGVLab/InternVL2-1B"
,
allow_patterns
=
DOWNLOAD_PATTERN
),
snapshot_download
(
"OpenGVLab/InternVL2-2B"
,
allow_patterns
=
DOWNLOAD_PATTERN
),
# Broken due to outdated implementation of Phi-3
# See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3
# snapshot_download("OpenGVLab/InternVL2-4B"),
]
...
...
@@ -41,8 +47,17 @@ class InternVLProcessor:
self
.
tokenizer
=
hf_runner
.
tokenizer
self
.
dtype
=
hf_runner
.
model
.
dtype
self
.
config
=
AutoConfig
.
from_pretrained
(
hf_runner
.
model_name
)
self
.
vision_config
=
self
.
config
.
vision_config
self
.
use_thumbnail
=
self
.
config
.
use_thumbnail
self
.
min_num
=
self
.
config
.
min_dynamic_patch
self
.
max_num
=
self
.
config
.
max_dynamic_patch
self
.
image_size
=
self
.
vision_config
.
image_size
def
__call__
(
self
,
text
:
str
,
images
:
Image
,
**
kwargs
):
pixel_values
=
image_to_pixel_values
(
images
).
to
(
self
.
dtype
)
pixel_values
=
image_to_pixel_values
(
images
,
self
.
image_size
,
self
.
min_num
,
self
.
max_num
,
self
.
use_thumbnail
).
to
(
self
.
dtype
)
num_patches_list
=
[
pixel_values
.
shape
[
0
]]
for
num_patches
in
num_patches_list
:
context_tokens
=
IMG_CONTEXT
*
self
.
num_image_token
*
num_patches
...
...
vllm/model_executor/models/internvl.py
View file @
b7645476
...
...
@@ -38,9 +38,6 @@ IMG_CONTEXT = '<IMG_CONTEXT>'
IMAGENET_MEAN
=
(
0.485
,
0.456
,
0.406
)
IMAGENET_STD
=
(
0.229
,
0.224
,
0.225
)
MAX_IMAGE_FEATURE_SIZE_WIDTH
=
3000
MAX_IMAGE_FEATURE_SIZE_HEIGHT
=
500
class
InternVLImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
...
...
@@ -84,11 +81,9 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
return
best_ratio
def
calculate_num_blocks
(
orig_width
:
int
,
orig_height
:
int
,
min_num
=
1
,
max_num
=
6
,
image_size
=
448
):
def
calculate_num_blocks
(
orig_width
:
int
,
orig_height
:
int
,
min_num
:
int
,
max_num
:
int
,
image_size
:
int
)
->
Tuple
[
int
,
int
,
int
]:
aspect_ratio
=
orig_width
/
orig_height
# calculate the existing image aspect ratio
...
...
@@ -110,11 +105,9 @@ def calculate_num_blocks(orig_width: int,
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def
dynamic_preprocess
(
image
,
min_num
=
1
,
max_num
=
6
,
image_size
=
448
,
use_thumbnail
=
False
):
def
dynamic_preprocess
(
image
:
Image
.
Image
,
min_num
:
int
,
max_num
:
int
,
image_size
:
int
,
use_thumbnail
:
int
)
->
List
[
Image
.
Image
]:
orig_width
,
orig_height
=
image
.
size
blocks
,
target_width
,
target_height
=
calculate_num_blocks
(
...
...
@@ -138,12 +131,14 @@ def dynamic_preprocess(image,
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def
image_to_pixel_values
(
image
:
Image
.
Image
,
input_size
=
448
,
max_num
=
6
):
def
image_to_pixel_values
(
image
:
Image
.
Image
,
input_size
:
int
,
min_num
:
int
,
max_num
:
int
,
use_thumbnail
:
bool
)
->
torch
.
Tensor
:
transform
=
build_transform
(
input_size
=
input_size
)
images
=
dynamic_preprocess
(
image
,
min_num
=
min_num
,
max_num
=
max_num
,
image_size
=
input_size
,
use_thumbnail
=
True
,
max_num
=
max_num
)
use_thumbnail
=
use_thumbnail
)
pixel_values
=
[
transform
(
image
)
for
image
in
images
]
pixel_values
=
torch
.
stack
(
pixel_values
)
return
pixel_values
...
...
@@ -159,12 +154,18 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
def
get_max_internvl_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
vision_config
=
hf_config
.
vision_config
use_thumbnail
=
hf_config
.
use_thumbnail
max_dynamic_patch
=
hf_config
.
max_dynamic_patch
if
use_thumbnail
:
max_dynamic_patch
+=
1
downsample_ratio
=
hf_config
.
downsample_ratio
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
return
num_patches
*
7
return
num_patches
*
max_dynamic_patch
def
input_processor_for_internvl
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
...
...
@@ -176,21 +177,27 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
vision_config
=
hf_config
.
vision_config
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
Image
.
Image
):
width
,
height
=
image_data
.
size
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
)
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
,
min_num
,
max_num
,
image_size
)
# add thumbnail image if num_blocks > 1
if
hf_config
.
use_thumbnail
and
num_blocks
>
1
:
num_blocks
+=
1
elif
isinstance
(
image_data
,
torch
.
Tensor
):
raise
NotImplementedError
(
"Embeddings input is not supported yet"
)
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
...
...
@@ -198,8 +205,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
if
prompt
is
None
:
prompt
=
tokenizer
.
decode
(
prompt_token_ids
)
image_prompt
=
IMG_START
+
IMG_CONTEXT
*
(
num_blocks
+
1
)
*
num_patches
+
IMG_END
image_prompt
=
IMG_START
+
IMG_CONTEXT
*
num_blocks
*
num_patches
+
IMG_END
new_prompt
=
prompt
.
replace
(
'<image>'
,
image_prompt
,
1
)
new_prompt_token_ids
=
tokenizer
.
encode
(
new_prompt
)
...
...
@@ -209,8 +215,19 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
def
input_mapper_for_internvl
(
ctx
:
InputContext
,
data
:
object
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
use_thumbnail
=
hf_config
.
use_thumbnail
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
image_size
=
hf_config
.
vision_config
.
image_size
if
isinstance
(
data
,
Image
.
Image
):
data
=
image_to_pixel_values
(
data
)
data
=
image_to_pixel_values
(
data
,
image_size
,
min_num
,
max_num
,
use_thumbnail
=
use_thumbnail
)
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
...
...
@@ -240,10 +257,17 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
add_special_tokens
=
False
)[
0
],
image_feature_size_override
=
image_feature_size
,
)
image_size
=
vision_config
.
image_size
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
max_image_width
=
max_num
*
image_size
max_image_height
=
min_num
*
image_size
mm_data
=
dummy_image_for_clip
(
vision_config
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
image_width_override
=
max_image_width
,
image_height_override
=
max_image_height
,
)
return
seq_data
,
mm_data
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment