Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
787cdb38
Unverified
Commit
787cdb38
authored
Aug 24, 2025
by
Benji Beck
Committed by
GitHub
Aug 25, 2025
Browse files
Migrate DonutImagePixelInputs to TensorSchema (#23509)
Signed-off-by:
Benji Beck
<
benjibeck@meta.com
>
parent
a5203d04
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
30 deletions
+19
-30
vllm/model_executor/models/donut.py
vllm/model_executor/models/donut.py
+19
-30
No files found.
vllm/model_executor/models/donut.py
View file @
787cdb38
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
Annotated
,
Literal
,
Optional
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
...
@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
PromptIndexTargets
,
PromptInsertion
,
PromptIndexTargets
,
PromptInsertion
,
PromptUpdate
)
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
class
MBartDecoderWrapper
(
nn
.
Module
):
class
MBartDecoderWrapper
(
nn
.
Module
):
...
@@ -132,10 +133,16 @@ class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
...
@@ -132,10 +133,16 @@ class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
return
loaded_params
return
loaded_params
class
DonutImagePixelInputs
(
TypedDict
):
class
DonutImagePixelInputs
(
TensorSchema
):
"""
Dimensions:
- b: Batch size
- c: Number of channels (3)
- h: Height
- w: Width
"""
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
data
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"b"
,
3
,
"h"
,
"w"
)]
"""Shape: (batch_size, num_channel, height, width)"""
class
DonutProcessingInfo
(
BaseProcessingInfo
):
class
DonutProcessingInfo
(
BaseProcessingInfo
):
...
@@ -275,27 +282,6 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -275,27 +282,6 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
)
)
self
.
pad_token_id
=
config
.
pad_token_id
self
.
pad_token_id
=
config
.
pad_token_id
def
_validate_pixel_values
(
self
,
data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
)
->
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
# size = self.processor_config["size"]
h
,
w
=
self
.
config
.
encoder
.
image_size
expected_dims
=
(
3
,
h
,
w
)
def
_validate_shape
(
d
:
torch
.
Tensor
):
actual_dims
=
tuple
(
d
.
shape
)
if
actual_dims
!=
expected_dims
:
raise
ValueError
(
"The expected shape of pixel values per batch "
f
"is
{
expected_dims
}
. You supplied
{
actual_dims
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
):
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
):
pixel_values
:
Optional
[
Union
[
list
[
list
[
torch
.
Tensor
]],
pixel_values
:
Optional
[
Union
[
list
[
list
[
torch
.
Tensor
]],
list
[
torch
.
Tensor
],
list
[
torch
.
Tensor
],
...
@@ -314,11 +300,14 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -314,11 +300,14 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
"Both pixel values and image embeds are provided."
)
"Both pixel values and image embeds are provided."
)
if
pixel_values
is
not
None
:
if
pixel_values
is
not
None
:
return
DonutImagePixelInputs
(
h
,
w
=
self
.
config
.
encoder
.
image_size
type
=
"pixel_values"
,
return
DonutImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
data
=
flatten_bn
(
pixel_values
,
flatten_bn
(
pixel_values
,
concat
=
True
)),
concat
=
True
),
)
resolve_bindings
=
{
"h"
:
h
,
"w"
:
w
,
})
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
raise
NotImplementedError
raise
NotImplementedError
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment