Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3d847a31
Unverified
Commit
3d847a31
authored
Jul 27, 2025
by
Isotr0py
Committed by
GitHub
Jul 27, 2025
Browse files
[VLM] Add video support for Intern-S1 (#21671)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
5f8c9a42
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
173 additions
and
50 deletions
+173
-50
docs/models/supported_models.md
docs/models/supported_models.md
+1
-1
examples/offline_inference/vision_language.py
examples/offline_inference/vision_language.py
+5
-3
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+1
-0
vllm/model_executor/models/interns1.py
vllm/model_executor/models/interns1.py
+166
-45
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+0
-1
No files found.
docs/models/supported_models.md
View file @
3d847a31
...
@@ -593,7 +593,7 @@ Specified using `--task generate`.
...
@@ -593,7 +593,7 @@ Specified using `--task generate`.
|
`GraniteSpeechForConditionalGeneration`
| Granite Speech | T + A |
`ibm-granite/granite-speech-3.3-8b`
| ✅︎ | ✅︎ | ✅︎ |
|
`GraniteSpeechForConditionalGeneration`
| Granite Speech | T + A |
`ibm-granite/granite-speech-3.3-8b`
| ✅︎ | ✅︎ | ✅︎ |
|
`H2OVLChatModel`
| H2OVL | T + I
<sup>
E+
</sup>
|
`h2oai/h2ovl-mississippi-800m`
,
`h2oai/h2ovl-mississippi-2b`
, etc. | | ✅︎ | ✅︎ |
|
`H2OVLChatModel`
| H2OVL | T + I
<sup>
E+
</sup>
|
`h2oai/h2ovl-mississippi-800m`
,
`h2oai/h2ovl-mississippi-2b`
, etc. | | ✅︎ | ✅︎ |
|
`Idefics3ForConditionalGeneration`
| Idefics3 | T + I |
`HuggingFaceM4/Idefics3-8B-Llama3`
, etc. | ✅︎ | | ✅︎ |
|
`Idefics3ForConditionalGeneration`
| Idefics3 | T + I |
`HuggingFaceM4/Idefics3-8B-Llama3`
, etc. | ✅︎ | | ✅︎ |
|
`InternS1ForConditionalGeneration`
| Intern-S1 | T + I
<sup>
E+
</sup>
|
`internlm/Intern-S1`
, etc. | | ✅︎ | ✅︎ |
|
`InternS1ForConditionalGeneration`
| Intern-S1 | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`internlm/Intern-S1`
, etc. | | ✅︎ | ✅︎ |
|
`InternVLChatModel`
| InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I
<sup>
E+
</sup>
+ (V
<sup>
E+
</sup>
) |
`OpenGVLab/InternVL3-9B`
,
`OpenGVLab/InternVideo2_5_Chat_8B`
,
`OpenGVLab/InternVL2_5-4B`
,
`OpenGVLab/Mono-InternVL-2B`
,
`OpenGVLab/InternVL2-4B`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`InternVLChatModel`
| InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I
<sup>
E+
</sup>
+ (V
<sup>
E+
</sup>
) |
`OpenGVLab/InternVL3-9B`
,
`OpenGVLab/InternVideo2_5_Chat_8B`
,
`OpenGVLab/InternVL2_5-4B`
,
`OpenGVLab/Mono-InternVL-2B`
,
`OpenGVLab/InternVL2-4B`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`KeyeForConditionalGeneration`
| Keye-VL-8B-Preview | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Kwai-Keye/Keye-VL-8B-Preview`
| | | ✅︎ |
|
`KeyeForConditionalGeneration`
| Keye-VL-8B-Preview | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Kwai-Keye/Keye-VL-8B-Preview`
| | | ✅︎ |
|
`KimiVLForConditionalGeneration`
| Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I
<sup>
+
</sup>
|
`moonshotai/Kimi-VL-A3B-Instruct`
,
`moonshotai/Kimi-VL-A3B-Thinking`
| | | ✅︎ |
|
`KimiVLForConditionalGeneration`
| Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I
<sup>
+
</sup>
|
`moonshotai/Kimi-VL-A3B-Instruct`
,
`moonshotai/Kimi-VL-A3B-Thinking`
| | | ✅︎ |
...
...
examples/offline_inference/vision_language.py
View file @
3d847a31
...
@@ -470,8 +470,6 @@ def run_tarsier(questions: list[str], modality: str) -> ModelRequestData:
...
@@ -470,8 +470,6 @@ def run_tarsier(questions: list[str], modality: str) -> ModelRequestData:
# Intern-S1
# Intern-S1
def
run_interns1
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
def
run_interns1
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
assert
modality
==
"image"
model_name
=
"internlm/Intern-S1"
model_name
=
"internlm/Intern-S1"
engine_args
=
EngineArgs
(
engine_args
=
EngineArgs
(
...
@@ -483,7 +481,11 @@ def run_interns1(questions: list[str], modality: str) -> ModelRequestData:
...
@@ -483,7 +481,11 @@ def run_interns1(questions: list[str], modality: str) -> ModelRequestData:
enforce_eager
=
True
,
enforce_eager
=
True
,
)
)
placeholder
=
"<IMG_CONTEXT>"
if
modality
==
"image"
:
placeholder
=
"<IMG_CONTEXT>"
elif
modality
==
"video"
:
placeholder
=
"<video>"
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
messages
=
[
messages
=
[
[{
"role"
:
"user"
,
"content"
:
f
"
{
placeholder
}
\n
{
question
}
"
}]
[{
"role"
:
"user"
,
"content"
:
f
"
{
placeholder
}
\n
{
question
}
"
}]
...
...
tests/models/multimodal/processing/test_common.py
View file @
3d847a31
...
@@ -278,6 +278,7 @@ def _test_processing_correctness_one(
...
@@ -278,6 +278,7 @@ def _test_processing_correctness_one(
"THUDM/GLM-4.1V-9B-Thinking"
,
"THUDM/GLM-4.1V-9B-Thinking"
,
"ibm-granite/granite-speech-3.3-2b"
,
"ibm-granite/granite-speech-3.3-2b"
,
"h2oai/h2ovl-mississippi-800m"
,
"h2oai/h2ovl-mississippi-800m"
,
"internlm/Intern-S1"
,
"OpenGVLab/InternVL2-1B"
,
"OpenGVLab/InternVL2-1B"
,
"OpenGVLab/InternVL3-1B"
,
"OpenGVLab/InternVL3-1B"
,
"HuggingFaceM4/Idefics3-8B-Llama3"
,
"HuggingFaceM4/Idefics3-8B-Llama3"
,
...
...
vllm/model_executor/models/interns1.py
View file @
3d847a31
...
@@ -9,9 +9,10 @@
...
@@ -9,9 +9,10 @@
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
Literal
,
Optional
,
TypedDict
,
Union
import
regex
as
re
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
InternVLProcessor
,
PretrainedConfig
from
transformers
import
BatchFeature
,
InternVLProcessor
,
PretrainedConfig
from
transformers.activations
import
ACT2FN
from
transformers.activations
import
ACT2FN
from
transformers.models.got_ocr2.image_processing_got_ocr2_fast
import
(
from
transformers.models.got_ocr2.image_processing_got_ocr2_fast
import
(
GotOcr2ImageProcessorFast
)
GotOcr2ImageProcessorFast
)
...
@@ -139,13 +140,13 @@ def get_interns1_target_ratios(
...
@@ -139,13 +140,13 @@ def get_interns1_target_ratios(
class
InternS1ProcessingInfo
(
BaseProcessingInfo
):
class
InternS1ProcessingInfo
(
BaseProcessingInfo
):
"""
Basic image-only
ProcessingInfo for InternS1-style models."""
"""ProcessingInfo for InternS1-style models."""
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
InternVLProcessor
:
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
InternVLProcessor
:
return
self
.
ctx
.
get_hf_processor
(
InternVLProcessor
,
**
kwargs
)
return
self
.
ctx
.
get_hf_processor
(
InternVLProcessor
,
**
kwargs
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
return
{
"image"
:
None
,
"video"
:
None
}
def
get_num_image_tokens
(
def
get_num_image_tokens
(
self
,
self
,
...
@@ -218,16 +219,35 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
...
@@ -218,16 +219,35 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
processor
=
processor
.
image_processor
,
processor
=
processor
.
image_processor
,
)
)
def
get_num_frames_with_most_features
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
int
:
max_images
=
mm_counts
.
get
(
"image"
,
0
)
max_videos
=
mm_counts
.
get
(
"video"
,
0
)
processor
=
self
.
get_hf_processor
()
max_image_tokens
=
self
.
get_max_image_tokens
()
*
max_images
max_total_frames
=
(
seq_len
-
max_image_tokens
)
//
processor
.
image_seq_length
max_frames_per_video
=
max_total_frames
//
max
(
max_videos
,
1
)
return
max
(
max_frames_per_video
,
1
)
class
InternS1DummyInputsBuilder
(
BaseDummyInputsBuilder
[
InternS1ProcessingInfo
]
class
InternS1DummyInputsBuilder
(
BaseDummyInputsBuilder
[
InternS1ProcessingInfo
]
):
):
"""
Basic image-only
DummyInputsBuilder for InternS1-style models."""
"""DummyInputsBuilder for InternS1-style models."""
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_videos
=
mm_counts
.
get
(
"video"
,
0
)
image_token
=
self
.
info
.
get_hf_processor
().
image_token
image_token
=
self
.
info
.
get_hf_processor
().
image_token
video_token
=
self
.
info
.
get_hf_processor
().
video_token
return
image_token
*
num_images
return
image_token
*
num_images
+
video_token
*
num_videos
def
get_dummy_mm_data
(
def
get_dummy_mm_data
(
self
,
self
,
...
@@ -236,13 +256,24 @@ class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
...
@@ -236,13 +256,24 @@ class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
)
->
MultiModalDataDict
:
)
->
MultiModalDataDict
:
target_width
,
target_height
=
\
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
self
.
info
.
get_image_size_with_most_features
()
target_num_frames
=
\
self
.
info
.
get_num_frames_with_most_features
(
seq_len
,
mm_counts
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_videos
=
mm_counts
.
get
(
"video"
,
0
)
config
=
self
.
info
.
get_hf_config
()
image_size_h
,
image_size_w
=
config
.
vision_config
.
image_size
return
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
num_images
=
num_images
)
num_images
=
num_images
),
"video"
:
self
.
_get_dummy_videos
(
width
=
image_size_w
,
height
=
image_size_h
,
num_frames
=
target_num_frames
,
num_videos
=
num_videos
),
}
}
...
@@ -257,33 +288,89 @@ class InternS1MultiModalProcessor(
...
@@ -257,33 +288,89 @@ class InternS1MultiModalProcessor(
mm_kwargs
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
tok_kwargs
:
Mapping
[
str
,
object
],
tok_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
NestedTensors
]:
)
->
Mapping
[
str
,
NestedTensors
]:
processed_outputs
=
super
().
_call_hf_processor
(
mm_data
=
dict
(
mm_data
)
prompt
=
prompt
,
videos
=
mm_data
.
pop
(
"videos"
,
[])
mm_data
=
mm_data
,
images
=
mm_data
.
pop
(
"images"
,
[])
mm_kwargs
=
mm_kwargs
,
assert
isinstance
(
videos
,
list
)
tok_kwargs
=
tok_kwargs
,
assert
isinstance
(
images
,
list
)
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
image_token_id
=
hf_processor
.
image_token_id
tokenizer
=
hf_processor
.
tokenizer
video_token_id
=
tokenizer
.
encode
(
hf_processor
.
video_token
,
# Since there may be extra tokens in the feature placeholders,
add_special_tokens
=
False
)
# we need to pass the image token ID to the model to select the
assert
len
(
video_token_id
)
==
1
# tokens to merge from the vision encoder outputs
video_token_id
=
video_token_id
[
0
]
processed_outputs
[
"image_token_id"
]
=
torch
.
tensor
(
image_token_id
)
images
=
mm_data
.
get
(
'images'
,
None
)
prompt
=
re
.
sub
(
hf_processor
.
image_token
,
"<image_placeholder>"
,
image_processor
=
self
.
info
.
get_hf_processor
().
image_processor
prompt
)
if
images
is
not
None
:
prompt
=
re
.
sub
(
hf_processor
.
video_token
,
"<video_placeholder>"
,
image_inputs
=
image_processor
(
images
=
images
)
prompt
)
image_num_patches
=
image_inputs
.
pop
(
"num_patches"
)
if
not
isinstance
(
image_num_patches
,
list
):
image_outputs
=
{}
raise
ValueError
(
if
images
:
f
'num_patches is supposed to be list, but got '
image_pixel_values
=
[]
f
'
{
type
(
image_num_patches
)
}
'
)
for
image
in
images
:
image_num_patches
=
torch
.
tensor
(
image_num_patches
)
processed_outputs
=
super
().
_call_hf_processor
(
processed_outputs
[
'image_num_patches'
]
=
image_num_patches
prompt
=
hf_processor
.
image_token
,
mm_data
=
{
"images"
:
image
},
return
processed_outputs
mm_kwargs
=
mm_kwargs
,
tok_kwargs
=
tok_kwargs
,
)
image_pixel_values
.
append
(
processed_outputs
.
pop
(
"pixel_values"
))
input_ids
=
processed_outputs
.
pop
(
"input_ids"
)
image_placeholder
=
tokenizer
.
batch_decode
(
input_ids
)[
0
]
prompt
=
prompt
.
replace
(
"<image_placeholder>"
,
image_placeholder
,
1
)
num_patches
=
[
len
(
item
)
for
item
in
image_pixel_values
]
image_outputs
:
dict
[
str
,
NestedTensors
]
=
{
"pixel_values"
:
torch
.
concat
(
image_pixel_values
),
"image_num_patches"
:
torch
.
tensor
(
num_patches
),
"image_token_id"
:
torch
.
tensor
(
hf_processor
.
image_token_id
),
}
video_outputs
=
{}
if
videos
:
video_pixel_values
=
[]
for
video
in
videos
:
processed_outputs
=
super
().
_call_hf_processor
(
prompt
=
hf_processor
.
video_token
,
mm_data
=
{
"videos"
:
video
},
mm_kwargs
=
mm_kwargs
,
tok_kwargs
=
tok_kwargs
,
)
video_pixel_values
.
append
(
processed_outputs
.
pop
(
"pixel_values"
))
input_ids
=
processed_outputs
.
pop
(
"input_ids"
)
input_ids
[
input_ids
==
hf_processor
.
image_token_id
]
=
video_token_id
video_placeholder
=
tokenizer
.
batch_decode
(
input_ids
)[
0
]
prompt
=
prompt
.
replace
(
"<video_placeholder>"
,
video_placeholder
,
1
)
num_frames
=
[
len
(
item
)
for
item
in
video_pixel_values
]
video_outputs
:
dict
[
str
,
NestedTensors
]
=
{
"pixel_values_videos"
:
torch
.
concat
(
video_pixel_values
),
"video_num_patches"
:
torch
.
tensor
(
num_frames
),
"video_token_id"
:
torch
.
tensor
(
video_token_id
),
}
prompt
=
re
.
sub
(
"<image_placeholder>"
,
hf_processor
.
image_token
,
prompt
)
prompt
=
re
.
sub
(
"<video_placeholder>"
,
hf_processor
.
video_token
,
prompt
)
text_outputs
=
tokenizer
(
prompt
,
**
tok_kwargs
,
return_tensors
=
"pt"
)
combined_outputs
=
dict
(
**
text_outputs
,
**
image_outputs
,
**
video_outputs
,
)
return
BatchFeature
(
combined_outputs
)
def
_get_mm_fields_config
(
def
_get_mm_fields_config
(
self
,
self
,
...
@@ -292,7 +379,9 @@ class InternS1MultiModalProcessor(
...
@@ -292,7 +379,9 @@ class InternS1MultiModalProcessor(
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
image_num_patches
=
hf_inputs
.
get
(
"image_num_patches"
,
torch
.
empty
(
0
))
image_num_patches
=
hf_inputs
.
get
(
"image_num_patches"
,
torch
.
empty
(
0
))
video_num_patches
=
hf_inputs
.
get
(
"video_num_patches"
,
torch
.
empty
(
0
))
num_images
=
len
(
image_num_patches
)
num_images
=
len
(
image_num_patches
)
num_videos
=
len
(
video_num_patches
)
return
dict
(
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
...
@@ -300,6 +389,10 @@ class InternS1MultiModalProcessor(
...
@@ -300,6 +389,10 @@ class InternS1MultiModalProcessor(
image_num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_token_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
image_token_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
pixel_values_videos
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_num_patches
),
video_num_patches
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_token_id
=
MultiModalFieldConfig
.
shared
(
"video"
,
num_videos
),
)
)
def
_get_prompt_updates
(
def
_get_prompt_updates
(
...
@@ -312,32 +405,61 @@ class InternS1MultiModalProcessor(
...
@@ -312,32 +405,61 @@ class InternS1MultiModalProcessor(
img_context_token
=
hf_processor
.
image_token
img_context_token
=
hf_processor
.
image_token
start_image_token
=
hf_processor
.
start_image_token
start_image_token
=
hf_processor
.
start_image_token
end_image_token
=
hf_processor
.
end_image_token
end_image_token
=
hf_processor
.
end_image_token
video_token
=
hf_processor
.
video_token
def
get_replacement
(
item_idx
:
int
):
if
"video_num_patches"
in
out_mm_kwargs
:
video_num_patches
=
out_mm_kwargs
[
"video_num_patches"
]
assert
isinstance
(
video_num_patches
,
torch
.
Tensor
)
video_num_patches
=
video_num_patches
.
tolist
()
else
:
video_num_patches
=
[]
if
"image_num_patches"
in
out_mm_kwargs
:
image_num_patches
=
out_mm_kwargs
[
"image_num_patches"
]
assert
isinstance
(
image_num_patches
,
torch
.
Tensor
)
image_num_patches
=
image_num_patches
.
tolist
()
else
:
image_num_patches
=
[]
def
get_replacement_interns1_image
(
item_idx
:
int
):
images
=
mm_items
.
get_items
(
images
=
mm_items
.
get_items
(
"image"
,
(
ImageEmbeddingItems
,
ImageProcessorItems
))
"image"
,
(
ImageEmbeddingItems
,
ImageProcessorItems
))
if
isinstance
(
images
,
ImageEmbeddingItems
):
if
isinstance
(
images
,
ImageEmbeddingItems
):
feature_size
=
images
.
get_feature_size
(
item_idx
)
feature_size
=
images
.
get_feature_size
(
item_idx
)
else
:
else
:
image_size
=
images
.
get_image_size
(
item_idx
)
num_patches
=
image_num_patches
[
item_idx
]
feature_size
=
self
.
info
.
get_num_image_tokens
(
feature_size
=
num_patches
*
hf_processor
.
image_seq_length
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
processor
=
hf_processor
.
image_processor
,
)
repl_features
=
img_context_token
*
feature_size
repl_features
=
img_context_token
*
feature_size
repl_full
=
start_image_token
+
repl_features
+
end_image_token
repl_full
=
start_image_token
+
repl_features
+
end_image_token
return
PromptUpdateDetails
.
select_text
(
repl_full
,
return
PromptUpdateDetails
.
select_text
(
repl_full
,
img_context_token
)
img_context_token
)
def
get_replacement_interns1_video
(
item_idx
:
int
):
num_patches
=
video_num_patches
[
item_idx
]
repl_features
=
video_token
*
hf_processor
.
image_seq_length
repl_features_with_sep
=
(
start_image_token
+
repl_features
+
end_image_token
)
# num_patches is equal to num_frames
repl_full
=
'
\n
'
.
join
([
f
'Frame
{
i
+
1
}
:
{
repl_features_with_sep
}
'
for
i
in
range
(
num_patches
)
])
return
PromptUpdateDetails
.
select_text
(
repl_full
,
video_token
)
return
[
return
[
PromptReplacement
(
PromptReplacement
(
modality
=
"image"
,
modality
=
"image"
,
target
=
img_context_token
,
target
=
img_context_token
,
replacement
=
get_replacement
,
replacement
=
get_replacement_interns1_image
,
)
),
PromptReplacement
(
modality
=
"video"
,
target
=
video_token
,
replacement
=
get_replacement_interns1_video
,
),
]
]
...
@@ -514,7 +636,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -514,7 +636,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
def
_parse_and_validate_video_input
(
def
_parse_and_validate_video_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
InternS1VideoPixelInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
InternS1VideoPixelInputs
]:
pixel_values_flat_video
=
kwargs
.
pop
(
"pixel_values_
flat_
video"
,
None
)
pixel_values_flat_video
=
kwargs
.
pop
(
"pixel_values_video
s
"
,
None
)
video_num_patches
=
kwargs
.
pop
(
"video_num_patches"
,
None
)
video_num_patches
=
kwargs
.
pop
(
"video_num_patches"
,
None
)
video_embeds
=
kwargs
.
pop
(
"video_embeds"
,
None
)
video_embeds
=
kwargs
.
pop
(
"video_embeds"
,
None
)
...
@@ -595,8 +717,8 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -595,8 +717,8 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
"image_embeds"
)
and
"images"
not
in
modalities
:
"image_embeds"
)
and
"images"
not
in
modalities
:
modalities
[
"images"
]
=
self
.
_parse_and_validate_image_input
(
modalities
[
"images"
]
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
**
kwargs
)
if
input_key
in
(
"pixel_values_flat_video"
,
if
input_key
in
(
)
and
"videos"
not
in
modalities
:
"pixel_values_videos"
,
)
and
"videos"
not
in
modalities
:
modalities
[
"videos"
]
=
self
.
_parse_and_validate_video_input
(
modalities
[
"videos"
]
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
**
kwargs
)
...
@@ -614,7 +736,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -614,7 +736,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
modalities
:
if
not
modalities
:
return
[]
return
[]
return
None
# The result multimodal_embeddings is tuple of tensors, with each
# The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video).
# tensor correspoending to a multimodal data item (image or video).
...
...
vllm/model_executor/models/internvl.py
View file @
3d847a31
...
@@ -1322,7 +1322,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -1322,7 +1322,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
modalities
:
if
not
modalities
:
return
[]
return
[]
return
None
# The result multimodal_embeddings is tuple of tensors, with each
# The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video).
# tensor correspoending to a multimodal data item (image or video).
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment