Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fc679696
Unverified
Commit
fc679696
authored
Oct 06, 2025
by
Chatcharin Sangbutsarakum
Committed by
GitHub
Oct 06, 2025
Browse files
Fix `DotsOCR` tensor type (#26281)
Signed-off-by:
what_in_the_nim
<
chatcharinsang@gmail.com
>
parent
ab5e7d93
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
6 deletions
+6
-6
vllm/model_executor/models/dots_ocr.py
vllm/model_executor/models/dots_ocr.py
+6
-6
No files found.
vllm/model_executor/models/dots_ocr.py
View file @
fc679696
...
...
@@ -617,7 +617,7 @@ class DotsVisionTransformer(nn.Module):
def
device
(
self
)
->
torch
.
device
:
return
self
.
patch_embed
.
patchifier
.
proj
.
weight
.
device
def
get_pos_ids_by_grid
(
self
,
grid_thw
)
:
def
get_pos_ids_by_grid
(
self
,
grid_thw
:
list
[
list
[
int
]])
->
list
[
torch
.
Tensor
]
:
pos_ids
=
[]
for
t
,
h
,
w
in
grid_thw
:
hpos_ids
=
torch
.
arange
(
h
).
unsqueeze
(
1
).
expand
(
-
1
,
w
)
...
...
@@ -643,10 +643,10 @@ class DotsVisionTransformer(nn.Module):
return
pos_ids
def
rot_pos_emb
(
self
,
grid_thw
)
:
def
rot_pos_emb
(
self
,
grid_thw
:
list
[
list
[
int
]])
->
torch
.
Tensor
:
pos_ids
=
self
.
get_pos_ids_by_grid
(
grid_thw
)
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
max_grid_size
=
grid_thw
[:,
1
:].
max
(
)
max_grid_size
=
max
(
max
(
h
,
w
)
for
_
,
h
,
w
in
grid_thw
)
rotary_pos_emb_full
=
self
.
rotary_pos_emb
(
max_grid_size
)
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
return
rotary_pos_emb
...
...
@@ -667,13 +667,13 @@ class DotsVisionTransformer(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
grid_thw
:
list
[
list
[
int
]]
)
->
torch
.
Tensor
:
rotary_pos_emb
=
self
.
rot_pos_emb
(
grid_thw
)
# Convert grid_thw to tensor (always expecting list format now)
grid_thw
=
torch
.
tensor
(
grid_thw
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
long
)
hidden_states
=
hidden_states
.
to
(
self
.
dtype
)
hidden_states
=
self
.
patch_embed
(
hidden_states
,
grid_thw
)
rotary_pos_emb
=
self
.
rot_pos_emb
(
grid_thw
)
cu_seqlens
=
torch
.
repeat_interleave
(
grid_thw
[:,
1
]
*
grid_thw
[:,
2
],
grid_thw
[:,
0
]
).
cumsum
(
...
...
@@ -807,7 +807,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
rope_type
=
"rope_3d"
,
)
else
:
image_embeds
=
self
.
vision_tower
(
pixel_values
,
grid_thw
)[
image_embeds
=
self
.
vision_tower
(
pixel_values
,
grid_thw
_list
)[
:,
:
self
.
config
.
hidden_size
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment