Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
722d46ed
Unverified
Commit
722d46ed
authored
Oct 24, 2024
by
Alex Brooks
Committed by
GitHub
Oct 24, 2024
Browse files
[Model] Compute Llava Next Max Tokens / Dummy Data From Gridpoints (#9650)
Signed-off-by:
Alex-Brooks
<
Alex.Brooks@ibm.com
>
parent
c866e007
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
93 additions
and
14 deletions
+93
-14
tests/models/decoder_only/vision_language/test_llava_next.py
tests/models/decoder_only/vision_language/test_llava_next.py
+65
-1
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+28
-13
No files found.
tests/models/decoder_only/vision_language/test_llava_next.py
View file @
722d46ed
...
@@ -3,12 +3,13 @@ from typing import List, Optional, Tuple, Type, overload
...
@@ -3,12 +3,13 @@ from typing import List, Optional, Tuple, Type, overload
import
pytest
import
pytest
from
transformers
import
AutoConfig
,
AutoModelForVision2Seq
,
AutoTokenizer
from
transformers
import
AutoConfig
,
AutoModelForVision2Seq
,
AutoTokenizer
from
vllm.inputs
import
InputContext
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
SampleLogprobs
from
....conftest
import
(
IMAGE_ASSETS
,
HfRunner
,
PromptImageInput
,
VllmRunner
,
from
....conftest
import
(
IMAGE_ASSETS
,
HfRunner
,
PromptImageInput
,
VllmRunner
,
_ImageAssets
)
_ImageAssets
)
from
...utils
import
check_logprobs_close
from
...utils
import
build_model_context
,
check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT
=
4
_LIMIT_IMAGE_PER_PROMPT
=
4
...
@@ -22,6 +23,19 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
...
@@ -22,6 +23,19 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
models
=
[
"llava-hf/llava-v1.6-mistral-7b-hf"
]
models
=
[
"llava-hf/llava-v1.6-mistral-7b-hf"
]
@
pytest
.
fixture
()
def
get_max_llava_next_image_tokens
():
from
vllm.model_executor.models.llava_next
import
(
get_max_llava_next_image_tokens
)
return
get_max_llava_next_image_tokens
@
pytest
.
fixture
()
def
dummy_data_for_llava_next
():
from
vllm.model_executor.models.llava_next
import
dummy_data_for_llava_next
return
dummy_data_for_llava_next
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
Optional
[
SampleLogprobs
]],
model
:
str
):
model
:
str
):
...
@@ -281,3 +295,53 @@ def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
...
@@ -281,3 +295,53 @@ def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
num_logprobs
=
num_logprobs
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
,
)
)
@
pytest
.
mark
.
parametrize
(
"gridpoints,expected_max_tokens"
,
[
([[
336
,
336
]],
1176
),
([[
336
,
672
],
[
672
,
336
],
[
672
,
672
],
[
1008
,
336
],
[
336
,
1008
]],
2928
),
])
def
test_get_max_llava_next_image_tokens
(
gridpoints
,
expected_max_tokens
,
get_max_llava_next_image_tokens
):
ctx
=
build_model_context
(
model_name
=
"llava-hf/llava-v1.6-mistral-7b-hf"
)
# Update the config image_grid_pinpoints
# and calculate the resulting max tokens
ctx
.
model_config
.
hf_config
.
image_grid_pinpoints
=
gridpoints
actual_max_tokens
=
get_max_llava_next_image_tokens
(
InputContext
(
ctx
.
model_config
))
assert
expected_max_tokens
==
actual_max_tokens
@
pytest
.
mark
.
parametrize
(
"gridpoints,expected_size"
,
[
# One point; it has to be the largest
([[
336
,
336
]],
(
336
,
336
)),
# Default for most llava next models; the 2x2 tile is the largest
([[
336
,
672
],
[
672
,
336
],
[
672
,
672
],
[
1008
,
336
],
[
336
,
1008
]],
(
672
,
672
)),
# If two rectangular gridpoints are the same, the more vertical
# one has the higher feature count due to newline features
([[
336
,
672
],
[
672
,
336
]],
(
672
,
336
))
])
def
test_dummy_data_for_llava_next_feature_size
(
dummy_data_for_llava_next
,
gridpoints
,
expected_size
):
ctx
=
build_model_context
(
model_name
=
"llava-hf/llava-v1.6-mistral-7b-hf"
)
# Update the config image_grid_pinpoints
ctx
.
model_config
.
hf_config
.
image_grid_pinpoints
=
gridpoints
seq_len
=
5000
# bigger than the max feature size for any image
seq_data
,
mm_data
=
dummy_data_for_llava_next
(
ctx
,
seq_len
=
seq_len
,
mm_counts
=
{
"image"
:
1
},
)
# The dummy data dims should match the gridpoint with the biggest feat size
assert
mm_data
[
"image"
].
height
==
expected_size
[
0
]
assert
mm_data
[
"image"
].
width
==
expected_size
[
1
]
assert
len
(
seq_data
.
get_token_ids
())
>=
seq_len
vllm/model_executor/models/llava_next.py
View file @
722d46ed
...
@@ -33,9 +33,6 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
...
@@ -33,9 +33,6 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
from
.utils
import
(
AutoWeightsLoader
,
embed_multimodal
,
flatten_bn
,
from
.utils
import
(
AutoWeightsLoader
,
embed_multimodal
,
flatten_bn
,
init_vllm_registered_model
)
init_vllm_registered_model
)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
=
448
class
LlavaNextImagePixelInputs
(
TypedDict
):
class
LlavaNextImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
...
@@ -149,11 +146,28 @@ def get_llava_next_image_feature_size(
...
@@ -149,11 +146,28 @@ def get_llava_next_image_feature_size(
def
get_max_llava_next_image_tokens
(
ctx
:
InputContext
):
def
get_max_llava_next_image_tokens
(
ctx
:
InputContext
):
return
get_llava_next_image_feature_size
(
"""Compute the max feature size for all possible image grid pinpoints."""
ctx
.
get_hf_config
(
LlavaNextConfig
),
return
_get_pinpoint_with_largest_features
(
ctx
)[
0
]
input_height
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
input_width
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
)
def
_get_pinpoint_with_largest_features
(
ctx
:
InputContext
)
->
Tuple
[
int
,
Tuple
[
int
,
int
]]:
"""Get the grid pinpoint with the largest features & its feature size."""
hf_config
=
ctx
.
get_hf_config
(
LlavaNextConfig
)
largest_feature_size
=
0
largest_feature_pinpoint
=
None
for
(
height
,
width
)
in
hf_config
.
image_grid_pinpoints
:
feat_size
=
get_llava_next_image_feature_size
(
hf_config
,
input_height
=
height
,
input_width
=
width
,
)
if
feat_size
>
largest_feature_size
:
largest_feature_size
=
feat_size
largest_feature_pinpoint
=
(
height
,
width
)
if
not
largest_feature_size
or
largest_feature_pinpoint
is
None
:
raise
ValueError
(
"Cannot have a largest feature size of 0!"
)
return
largest_feature_size
,
largest_feature_pinpoint
def
dummy_data_for_llava_next
(
ctx
:
InputContext
,
seq_len
:
int
,
def
dummy_data_for_llava_next
(
ctx
:
InputContext
,
seq_len
:
int
,
...
@@ -162,7 +176,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
...
@@ -162,7 +176,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_llava_next_image_tokens
(
ctx
)
image_feature_size
,
pinpoint
=
_get_pinpoint_with_largest_features
(
ctx
)
max_feat_height
,
max_feat_width
=
pinpoint
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
seq_data
=
dummy_seq_data_for_clip
(
seq_data
=
dummy_seq_data_for_clip
(
...
@@ -176,8 +191,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
...
@@ -176,8 +191,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
mm_data
=
dummy_image_for_clip
(
mm_data
=
dummy_image_for_clip
(
vision_config
,
vision_config
,
num_images
,
num_images
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_width_override
=
max_feat_width
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
image_height_override
=
max_feat_height
,
)
)
return
seq_data
,
mm_data
return
seq_data
,
mm_data
...
@@ -193,8 +208,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
...
@@ -193,8 +208,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
mm_data
=
dummy_image_for_siglip
(
mm_data
=
dummy_image_for_siglip
(
vision_config
,
vision_config
,
num_images
,
num_images
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_width_override
=
max_feat_width
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
image_height_override
=
max_feat_height
,
)
)
return
seq_data
,
mm_data
return
seq_data
,
mm_data
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment