Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9e0a147d
Unverified
Commit
9e0a147d
authored
Nov 27, 2024
by
Roger Wang
Committed by
GitHub
Nov 27, 2024
Browse files
[V1] Update interface for mistral-format Pixtral (#10703)
Signed-off-by:
Roger Wang
<
ywang@roblox.com
>
parent
418cb3b9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
19 deletions
+28
-19
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+28
-19
No files found.
vllm/model_executor/models/pixtral.py
View file @
9e0a147d
...
...
@@ -31,7 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.models.utils
import
merge_multimodal_embeddings
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.multimodal.inputs
import
NestedTensors
,
PlaceholderRange
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
consecutive_placeholder_ranges
,
resolve_visual_encoder_outputs
)
...
...
@@ -190,6 +190,25 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return
get_sampler
()
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
return
vision_embeddings
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
,
multimodal_embeddings
:
Optional
[
NestedTensors
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
self
.
vision_args
.
image_token_id
)
return
inputs_embeds
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -197,31 +216,21 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
"""Run forward pass for pixtral.
TODO
"""
if
intermediate_tensors
is
not
None
:
input_ids
=
None
inputs_embeds
=
None
else
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
vision_args
.
image_token_id
)
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif
inputs_embeds
is
None
:
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
vision_embeddings
)
input_ids
=
None
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment