Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a01e0018
Unverified
Commit
a01e0018
authored
Aug 13, 2025
by
Duc-Viet Hoang
Committed by
GitHub
Aug 13, 2025
Browse files
[Bugfix] Fix Nemotron VL image processing (#22739)
Co-authored-by:
ducviet00-h2
<
viet.d.hoang@h2corporation.jp
>
parent
9e7e5baa
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
190 additions
and
4 deletions
+190
-4
tests/models/multimodal/processing/test_nemotron_vl.py
tests/models/multimodal/processing/test_nemotron_vl.py
+4
-4
vllm/model_executor/models/nemotron_vl.py
vllm/model_executor/models/nemotron_vl.py
+186
-0
No files found.
tests/models/multimodal/processing/test_nemotron_vl.py
View file @
a01e0018
...
...
@@ -23,15 +23,15 @@ def _get_expected_num_patches(
min_num
:
int
,
max_num
:
int
,
):
from
vllm.model_executor.models.
intern
vl
import
(
calculate_
intern
vl_targets
,
get_
intern
vl_target_ratios
)
from
vllm.model_executor.models.
nemotron_
vl
import
(
calculate_
nemotron_
vl_targets
,
get_
nemotron_
vl_target_ratios
)
width
,
height
=
image
.
size
blocks
,
_
,
_
=
calculate_
intern
vl_targets
(
blocks
,
_
,
_
=
calculate_
nemotron_
vl_targets
(
orig_width
=
width
,
orig_height
=
height
,
target_ratios
=
get_
intern
vl_target_ratios
(
target_ratios
=
get_
nemotron_
vl_target_ratios
(
min_num
,
max_num
,
),
...
...
vllm/model_executor/models/nemotron_vl.py
View file @
a01e0018
...
...
@@ -13,6 +13,7 @@ from typing import Optional
import
torch
import
torch.nn
as
nn
import
torchvision.transforms
as
T
from
PIL
import
Image
from
transformers
import
AutoModel
,
PretrainedConfig
from
transformers.image_processing_utils_fast
import
BaseImageProcessorFast
...
...
@@ -27,6 +28,7 @@ from vllm.model_executor.models.internvl import (
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
convert_image_mode
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.multimodal.processing
import
PromptUpdateDetails
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -44,6 +46,146 @@ IMG_END = '</img>'
# Token string whose tokenizer-vocabulary id is what
# `NemotronVLProcessor.image_token_id` returns.
IMG_CONTEXT = '<image>'
def build_transform(input_size: int):
    """Build the preprocessing pipeline applied to each image tile.

    The returned ``T.Compose`` runs three steps in order: conversion to
    RGB through ``convert_image_mode``, a bicubic resize to a square of
    side ``input_size``, and ``T.ToTensor()``.

    Args:
        input_size: Side length used for both dimensions of the resize.

    Returns:
        The composed torchvision transform.
    """
    to_rgb = T.Lambda(lambda img: convert_image_mode(img, 'RGB'))
    resize = T.Resize(
        (input_size, input_size),
        interpolation=T.InterpolationMode.BICUBIC,
    )
    steps = [to_rgb, resize, T.ToTensor()]
    return T.Compose(steps)
# adapted from https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the candidate tile grid that scores best for an image.

    Every candidate ``(w_tiles, h_tiles)`` gets a score that is the
    product of two terms:

    * coverage: the candidate's pixel area
      (``w_tiles * h_tiles * image_size * image_size``) divided by
      ``width * height``, capped at 0.6;
    * similarity: the smaller of ``candidate_ratio / aspect_ratio`` and
      ``aspect_ratio / candidate_ratio``, where ``candidate_ratio`` is
      ``w_tiles / h_tiles``.

    Args:
        aspect_ratio: Aspect ratio the candidates are compared against.
        target_ratios: Candidate grids, scanned in the order given.
        width: Image width, used only to compute the image area.
        height: Image height, used only to compute the image area.
        image_size: Side length of a single tile.

    Returns:
        The candidate with the strictly highest score; on equal scores
        the one that appears first in ``target_ratios`` is kept. Falls
        back to ``(1, 1)`` when no candidate beats negative infinity,
        which includes an empty ``target_ratios``.
    """
    image_area = width * height

    winner = (1, 1)
    top_score = float('-inf')
    for w_tiles, h_tiles in target_ratios:
        candidate_ratio = w_tiles / h_tiles
        coverage = min(
            (w_tiles * h_tiles * image_size * image_size) / image_area,
            0.6,
        )
        similarity = min(
            candidate_ratio / aspect_ratio,
            aspect_ratio / candidate_ratio,
        )
        score = coverage * similarity
        # Strict comparison: an equal score never displaces an earlier pick.
        if score > top_score:
            winner, top_score = (w_tiles, h_tiles), score
    return winner
def calculate_nemotron_vl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Compute the block count and resize target for an image.

    The tile grid is chosen by ``find_closest_aspect_ratio`` using the
    image's ``orig_width / orig_height`` ratio and its dimensions.

    Args:
        orig_width: Width of the source image.
        orig_height: Height of the source image.
        target_ratios: Candidate ``(w_tiles, h_tiles)`` grids.
        image_size: Side length of a single tile.
        use_thumbnail: Whether one extra block is counted for a thumbnail.

    Returns:
        ``(blocks, target_width, target_height)``. ``blocks`` is the
        number of tiles in the chosen grid, plus one when
        ``use_thumbnail`` is set and the grid has more than one tile.
        The target dimensions are the grid dimensions multiplied by
        ``image_size``.
    """
    w_tiles, h_tiles = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    tile_count = w_tiles * h_tiles
    # A single-tile grid never gets a thumbnail block.
    thumbnail_blocks = 1 if (use_thumbnail and tile_count != 1) else 0

    return (
        tile_count + thumbnail_blocks,
        image_size * w_tiles,
        image_size * h_tiles,
    )
def dynamic_preprocess_nemotron_vl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Split an image into square tiles, optionally adding a thumbnail.

    The image is resized to the target size reported by
    ``calculate_nemotron_vl_targets`` and then cropped into
    ``image_size`` x ``image_size`` tiles, left to right and top to
    bottom.

    Args:
        image: Source image.
        target_ratios: Candidate ``(w_tiles, h_tiles)`` grids.
        image_size: Side length of each tile.
        use_thumbnail: When true and more than one tile was produced, the
            whole source image resized to ``image_size`` x ``image_size``
            is appended as the last element.

    Returns:
        The list of tile images, followed by the thumbnail if one was
        added.
    """
    src_width, src_height = image.size

    # Ask for the plain tile count; the thumbnail is appended separately
    # at the end of this function.
    tile_count, canvas_width, canvas_height = calculate_nemotron_vl_targets(
        orig_width=src_width,
        orig_height=src_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    canvas = image.resize((canvas_width, canvas_height))

    tiles = []
    for index in range(tile_count):
        # Row-major position of this tile within the resized canvas.
        row, col = divmod(index, canvas_width // image_size)
        left = col * image_size
        top = row * image_size
        tiles.append(
            canvas.crop((left, top, left + image_size, top + image_size)))

    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles
def get_nemotron_vl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """List every tile grid whose tile count lies in a given range.

    Args:
        min_num: Smallest allowed ``w_tiles * h_tiles``.
        max_num: Largest allowed ``w_tiles * h_tiles``.

    Returns:
        All distinct ``(w_tiles, h_tiles)`` pairs of positive integers
        with ``min_num <= w_tiles * h_tiles <= max_num``, sorted by
        ascending tile count. The sort is stable, so pairs with the same
        tile count stay in the iteration order of the intermediate set.
        Empty when ``max_num < min_num``.
    """
    grids = set()
    # The insertion sequence is kept as-is on purpose: it can influence
    # the set's iteration order and therefore the order of equal-count
    # pairs in the result.
    for limit in range(min_num, max_num + 1):
        for w_tiles in range(1, limit + 1):
            for h_tiles in range(1, limit + 1):
                if min_num <= w_tiles * h_tiles <= max_num:
                    grids.add((w_tiles, h_tiles))

    ordered = list(grids)
    ordered.sort(key=lambda grid: grid[0] * grid[1])
    return ordered
def image_to_pixel_values_nemotron_vl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Convert one image into a stacked tensor of preprocessed tiles.

    Args:
        image: Source image.
        input_size: Tile side length, used both for tiling and for the
            per-tile transform.
        min_num: Lower bound on the tile count of candidate grids.
        max_num: Upper bound on the tile count of candidate grids.
        use_thumbnail: Forwarded to ``dynamic_preprocess_nemotron_vl``.

    Returns:
        ``torch.stack`` of the transformed tiles, in the order returned
        by ``dynamic_preprocess_nemotron_vl``.
    """
    candidate_grids = get_nemotron_vl_target_ratios(min_num, max_num)
    transform = build_transform(input_size=input_size)

    tiles = dynamic_preprocess_nemotron_vl(
        image,
        target_ratios=candidate_grids,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    tile_tensors = [transform(tile) for tile in tiles]
    return torch.stack(tile_tensors)
class
NemotronVLProcessor
(
InternVLProcessor
):
def
__init__
(
...
...
@@ -87,6 +229,50 @@ class NemotronVLProcessor(InternVLProcessor):
def image_token_id(self) -> int:
    """Return the tokenizer-vocabulary id of the ``IMG_CONTEXT`` token."""
    vocab = self.tokenizer.get_vocab()
    return vocab[IMG_CONTEXT]
def get_num_image_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
) -> int:
    """Return the number of image tokens for an image of the given size.

    Args:
        image_width: Width of the image.
        image_height: Height of the image.

    Returns:
        The block count from ``calculate_nemotron_vl_targets`` multiplied
        by ``self.num_image_token``.
    """
    # The thumbnail is not requested here; it is accounted for by
    # calculate_nemotron_vl_targets through ``use_thumbnail`` below.
    candidate_grids = self.resolve_target_ratios(use_thumbnail=False)

    targets = calculate_nemotron_vl_targets(
        orig_width=image_width,
        orig_height=image_height,
        image_size=self.image_size,
        target_ratios=candidate_grids,
        use_thumbnail=self.use_thumbnail,
    )
    block_count = targets[0]

    return block_count * self.num_image_token
def _images_to_pixel_values_lst(
    self,
    images: list[Image.Image],
    min_dynamic_patch: Optional[int] = None,
    max_dynamic_patch: Optional[int] = None,
    dynamic_image_size: Optional[bool] = None,
) -> list[torch.Tensor]:
    """Convert each image into its tensor of preprocessed tiles.

    Args:
        images: Images to convert.
        min_dynamic_patch: Forwarded to ``self.resolve_min_max_num``.
        max_dynamic_patch: Forwarded to ``self.resolve_min_max_num``.
        dynamic_image_size: Forwarded to ``self.resolve_min_max_num``.

    Returns:
        One tensor per input image, in input order, each produced by
        ``image_to_pixel_values_nemotron_vl``.
    """
    # The thumbnail is not requested when resolving the bounds; it is
    # handled inside image_to_pixel_values_nemotron_vl instead.
    min_num, max_num = self.resolve_min_max_num(
        min_dynamic_patch=min_dynamic_patch,
        max_dynamic_patch=max_dynamic_patch,
        dynamic_image_size=dynamic_image_size,
        use_thumbnail=False,
    )

    pixel_values_lst = []
    for img in images:
        pixel_values = image_to_pixel_values_nemotron_vl(
            img,
            input_size=self.image_size,
            min_num=min_num,
            max_num=max_num,
            use_thumbnail=self.use_thumbnail,
        )
        pixel_values_lst.append(pixel_values)
    return pixel_values_lst
def
_preprocess_image
(
self
,
text
:
list
[
str
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment