Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
52d0cb84
Unverified
Commit
52d0cb84
authored
Sep 25, 2025
by
Jee Jee Li
Committed by
GitHub
Sep 25, 2025
Browse files
[Model] Improve DotsOCRForCausalLM (#25466)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
5c1e496a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
143 additions
and
94 deletions
+143
-94
vllm/model_executor/models/dots_ocr.py
vllm/model_executor/models/dots_ocr.py
+143
-94
No files found.
vllm/model_executor/models/dots_ocr.py
View file @
52d0cb84
...
@@ -7,11 +7,13 @@ import torch
...
@@ -7,11 +7,13 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.nn
import
LayerNorm
from
torch.nn
import
LayerNorm
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.models.qwen2_vl
import
Qwen2VLProcessor
from
transformers.models.qwen2_vl
import
Qwen2VLProcessor
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -19,10 +21,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -19,10 +21,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
MultiModalEmbeddings
,
from
vllm.model_executor.models.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsMultiModal
,
SupportsPP
)
SupportsPP
)
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.qwen2
import
Qwen2ForCausalLM
from
vllm.model_executor.models.qwen2
import
Qwen2ForCausalLM
from
vllm.model_executor.models.qwen2_5_vl
import
Qwen2_5_VisionAttention
from
vllm.model_executor.models.qwen2_vl
import
(
Qwen2VLDummyInputsBuilder
,
from
vllm.model_executor.models.qwen2_vl
import
(
Qwen2VLDummyInputsBuilder
,
Qwen2VLMultiModalProcessor
,
Qwen2VLMultiModalProcessor
,
Qwen2VLProcessingInfo
)
Qwen2VLProcessingInfo
)
...
@@ -38,6 +44,8 @@ from vllm.sequence import IntermediateTensors
...
@@ -38,6 +44,8 @@ from vllm.sequence import IntermediateTensors
from
vllm.transformers_utils.configs.dotsocr
import
(
DotsOCRConfig
,
from
vllm.transformers_utils.configs.dotsocr
import
(
DotsOCRConfig
,
DotsVisionConfig
)
DotsVisionConfig
)
from
.vision
import
run_dp_sharded_mrope_vision_model
IMAGE_TOKEN
=
"<|imgpad|>"
IMAGE_TOKEN
=
"<|imgpad|>"
...
@@ -181,6 +189,8 @@ class PatchMerger(nn.Module):
...
@@ -181,6 +189,8 @@ class PatchMerger(nn.Module):
context_dim
:
int
,
context_dim
:
int
,
spatial_merge_size
:
int
=
2
,
spatial_merge_size
:
int
=
2
,
pre_norm
=
"layernorm"
,
pre_norm
=
"layernorm"
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
context_dim
*
(
spatial_merge_size
**
2
)
self
.
hidden_size
=
context_dim
*
(
spatial_merge_size
**
2
)
...
@@ -189,21 +199,21 @@ class PatchMerger(nn.Module):
...
@@ -189,21 +199,21 @@ class PatchMerger(nn.Module):
self
.
ln_q
=
LayerNorm
(
context_dim
,
eps
=
1e-6
)
self
.
ln_q
=
LayerNorm
(
context_dim
,
eps
=
1e-6
)
elif
self
.
pre_norm
==
"rmsnorm"
:
elif
self
.
pre_norm
==
"rmsnorm"
:
self
.
ln_q
=
RMSNorm
(
context_dim
,
eps
=
1e-6
)
self
.
ln_q
=
RMSNorm
(
context_dim
,
eps
=
1e-6
)
else
:
print
(
"no norm in patch merger"
)
self
.
mlp
=
nn
.
Sequential
(
self
.
mlp
=
nn
.
Sequential
(
ColumnParallelLinear
(
self
.
hidden_size
,
ColumnParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
bias
=
True
,
return_bias
=
False
,
return_bias
=
False
,
disable_tp
=
True
),
prefix
=
f
"
{
prefix
}
.0"
,
disable_tp
=
use_data_parallel
),
nn
.
GELU
(),
nn
.
GELU
(),
RowParallelLinear
(
self
.
hidden_size
,
RowParallelLinear
(
self
.
hidden_size
,
dim
,
dim
,
bias
=
True
,
bias
=
True
,
return_bias
=
False
,
return_bias
=
False
,
disable_tp
=
True
),
prefix
=
f
"
{
prefix
}
.2"
,
disable_tp
=
use_data_parallel
),
)
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -223,38 +233,36 @@ class DotsVisionAttention(nn.Module):
...
@@ -223,38 +233,36 @@ class DotsVisionAttention(nn.Module):
bias
:
bool
=
True
,
bias
:
bool
=
True
,
*
,
*
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
)
->
None
:
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
)
->
None
:
super
().
__init__
()
super
().
__init__
()
from
vllm.distributed
import
(
parallel_state
,
tensor_model_parallel_all_gather
)
from
vllm.distributed
import
utils
as
dist_utils
self
.
embed_dim
=
dim
self
.
embed_dim
=
dim
self
.
num_heads
=
num_heads
self
.
tp_size
=
(
1
if
use_data_parallel
else
self
.
head_dim
=
dim
//
num_heads
get_tensor_model_parallel_world_size
())
self
.
tp_size
=
parallel_state
.
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
(
0
if
use_data_parallel
else
self
.
tp_rank
=
parallel_state
.
get_tensor_model_parallel_rank
()
get_tensor_model_parallel_rank
())
self
.
num_heads_per_partition
=
dist_utils
.
divide
(
self
.
hidden_size_per_attention_head
=
dist_utils
.
divide
(
dim
,
num_heads
)
self
.
num_attention_heads_per_partition
=
dist_utils
.
divide
(
num_heads
,
self
.
tp_size
)
num_heads
,
self
.
tp_size
)
# qkv/proj follow Qwen2-VL style; bias controlled by arg
# qkv/proj follow Qwen2-VL style; bias controlled by arg
self
.
qkv
=
QKVParallelLinear
(
hidden_size
=
dim
,
self
.
qkv
=
QKVParallelLinear
(
head_size
=
dim
//
num_heads
,
hidden_size
=
dim
,
head_size
=
self
.
hidden_size_per_attention_head
,
total_num_heads
=
num_heads
,
total_num_heads
=
num_heads
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv"
)
prefix
=
f
"
{
prefix
}
.qkv"
,
disable_tp
=
use_data_parallel
)
self
.
proj
=
RowParallelLinear
(
input_size
=
dim
,
self
.
proj
=
RowParallelLinear
(
input_size
=
dim
,
output_size
=
dim
,
output_size
=
dim
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.proj"
)
prefix
=
f
"
{
prefix
}
.proj"
,
self
.
_all_gather
=
tensor_model_parallel_all_gather
disable_tp
=
use_data_parallel
)
self
.
_split_last
=
dist_utils
.
split_tensor_along_last_dim
# Select attention backend
# Select attention backend
self
.
attn_backend
=
get_vit_attn_backend
(
self
.
head_dim
,
self
.
attn_backend
=
get_vit_attn_backend
(
torch
.
get_default_dtype
())
self
.
hidden_size_per_attention_head
,
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
...
@@ -270,19 +278,6 @@ class DotsVisionAttention(nn.Module):
...
@@ -270,19 +278,6 @@ class DotsVisionAttention(nn.Module):
_Backend
.
FLASH_ATTN
,
_Backend
.
ROCM_AITER_FA
_Backend
.
FLASH_ATTN
,
_Backend
.
ROCM_AITER_FA
}
}
def
_split_qkv
(
self
,
qkv
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
# qkv: [S, B, 3*dim]
seq_len
,
bs
,
_
=
qkv
.
shape
if
self
.
tp_size
>
1
:
qkv
=
self
.
_all_gather
(
qkv
)
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=
2
)
if
self
.
tp_size
>
1
:
q
=
self
.
_split_last
(
q
,
num_partitions
=
self
.
tp_size
)[
self
.
tp_rank
]
k
=
self
.
_split_last
(
k
,
num_partitions
=
self
.
tp_size
)[
self
.
tp_rank
]
v
=
self
.
_split_last
(
v
,
num_partitions
=
self
.
tp_size
)[
self
.
tp_rank
]
new_shape
=
(
seq_len
,
bs
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
return
(
q
.
view
(
*
new_shape
),
k
.
view
(
*
new_shape
),
v
.
view
(
*
new_shape
))
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -295,7 +290,7 @@ class DotsVisionAttention(nn.Module):
...
@@ -295,7 +290,7 @@ class DotsVisionAttention(nn.Module):
# [S, C] -> [S, B=1, C]
# [S, C] -> [S, B=1, C]
x
=
hidden_states
.
unsqueeze
(
1
)
x
=
hidden_states
.
unsqueeze
(
1
)
x
,
_
=
self
.
qkv
(
x
)
x
,
_
=
self
.
qkv
(
x
)
q
,
k
,
v
=
self
.
_
split_qkv
(
x
)
q
,
k
,
v
=
Qwen2_5_VisionAttention
.
split_qkv
(
self
,
x
)
bs
=
q
.
shape
[
1
]
bs
=
q
.
shape
[
1
]
# [S,B,H,D] -> [B,S,H,D]
# [S,B,H,D] -> [B,S,H,D]
q
=
q
.
permute
(
1
,
0
,
2
,
3
).
contiguous
()
q
=
q
.
permute
(
1
,
0
,
2
,
3
).
contiguous
()
...
@@ -327,8 +322,9 @@ class DotsVisionAttention(nn.Module):
...
@@ -327,8 +322,9 @@ class DotsVisionAttention(nn.Module):
max_seqlen_k
=
max_seqlen
,
max_seqlen_k
=
max_seqlen
,
dropout_p
=
0.0
,
dropout_p
=
0.0
,
causal
=
False
)
causal
=
False
)
context_layer
=
output
.
view
(
bs
,
-
1
,
self
.
num_heads_per_partition
,
context_layer
=
output
.
view
(
bs
,
-
1
,
self
.
head_dim
)
self
.
num_attention_heads_per_partition
,
self
.
hidden_size_per_attention_head
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
outputs
=
[]
outputs
=
[]
for
i
in
range
(
1
,
len
(
cu_seqlens
)):
for
i
in
range
(
1
,
len
(
cu_seqlens
)):
...
@@ -368,7 +364,8 @@ class DotsSwiGLUFFN(nn.Module):
...
@@ -368,7 +364,8 @@ class DotsSwiGLUFFN(nn.Module):
config
,
config
,
*
,
*
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
):
super
().
__init__
()
super
().
__init__
()
hidden_features
=
config
.
intermediate_size
hidden_features
=
config
.
intermediate_size
in_features
=
config
.
embed_dim
in_features
=
config
.
embed_dim
...
@@ -380,13 +377,13 @@ class DotsSwiGLUFFN(nn.Module):
...
@@ -380,13 +377,13 @@ class DotsSwiGLUFFN(nn.Module):
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc13"
,
prefix
=
f
"
{
prefix
}
.fc13"
,
disable_tp
=
True
)
disable_tp
=
use_data_parallel
)
self
.
fc2
=
RowParallelLinear
(
hidden_features
,
self
.
fc2
=
RowParallelLinear
(
hidden_features
,
in_features
,
in_features
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
,
prefix
=
f
"
{
prefix
}
.fc2"
,
disable_tp
=
True
)
disable_tp
=
use_data_parallel
)
self
.
act_fn
=
SiluAndMul
()
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -397,28 +394,36 @@ class DotsSwiGLUFFN(nn.Module):
...
@@ -397,28 +394,36 @@ class DotsSwiGLUFFN(nn.Module):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
params
=
dict
(
self
.
named_parameters
())
stacked_params_mapping
=
[
loaded
:
set
[
str
]
=
set
()
(
"fc13"
,
"fc1"
,
0
),
for
name
,
w
in
weights
:
(
"fc13"
,
"fc3"
,
1
),
# Map fc1 -> fc13 (shard 0)
]
if
name
.
startswith
(
"fc1."
):
params_dict
=
dict
(
self
.
named_parameters
())
tgt
=
name
.
replace
(
"fc1."
,
"fc13."
)
loaded_params
:
set
[
str
]
=
set
()
if
tgt
in
params
:
for
name
,
loaded_weight
in
weights
:
params
[
tgt
].
weight_loader
(
params
[
tgt
],
w
,
0
)
loaded
.
add
(
tgt
)
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Map fc3 -> fc13 (shard 1)
param
=
params_dict
[
name
]
if
name
.
startswith
(
"fc3."
):
weight_loader
=
param
.
weight_loader
tgt
=
name
.
replace
(
"fc3."
,
"fc13."
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
if
tgt
in
params
:
break
params
[
tgt
].
weight_loader
(
params
[
tgt
],
w
,
1
)
else
:
loaded
.
add
(
tgt
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Pass-through for fc2 and others
if
name
in
params
:
param
=
params_dict
[
name
]
params
[
name
].
weight_loader
(
params
[
name
],
w
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
loaded
.
add
(
name
)
default_weight_loader
)
return
loaded
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
DotsPatchEmbed
(
nn
.
Module
):
class
DotsPatchEmbed
(
nn
.
Module
):
...
@@ -463,25 +468,28 @@ class DotsViTPreprocessor(nn.Module):
...
@@ -463,25 +468,28 @@ class DotsViTPreprocessor(nn.Module):
class
DotsVisionBlock
(
nn
.
Module
):
class
DotsVisionBlock
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
config
,
config
,
*
,
*
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
):
super
().
__init__
()
super
().
__init__
()
self
.
attn
=
DotsVisionAttention
(
self
.
attn
=
DotsVisionAttention
(
config
,
config
,
config
.
embed_dim
,
config
.
embed_dim
,
num_heads
=
config
.
num_attention_heads
,
num_heads
=
config
.
num_attention_heads
,
bias
=
config
.
use_bias
,
bias
=
config
.
use_bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
use_data_parallel
=
use_data_parallel
)
self
.
norm1
=
RMSNorm
(
config
.
embed_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
norm1
=
RMSNorm
(
config
.
embed_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
mlp
=
DotsSwiGLUFFN
(
config
,
self
.
mlp
=
DotsSwiGLUFFN
(
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
prefix
=
f
"
{
prefix
}
.mlp"
,
use_data_parallel
=
use_data_parallel
)
self
.
norm2
=
RMSNorm
(
config
.
embed_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
norm2
=
RMSNorm
(
config
.
embed_dim
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
def
forward
(
self
,
...
@@ -502,7 +510,7 @@ class DotsVisionBlock(nn.Module):
...
@@ -502,7 +510,7 @@ class DotsVisionBlock(nn.Module):
return
hidden_states
return
hidden_states
class
DotsVisionTransformer
(
PreTrainedModel
):
class
DotsVisionTransformer
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -512,8 +520,9 @@ class DotsVisionTransformer(PreTrainedModel):
...
@@ -512,8 +520,9 @@ class DotsVisionTransformer(PreTrainedModel):
num_hidden_layers_override
:
Optional
[
int
]
=
None
,
num_hidden_layers_override
:
Optional
[
int
]
=
None
,
require_post_norm
:
Optional
[
bool
]
=
None
,
require_post_norm
:
Optional
[
bool
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
spatial_merge_size
=
config
.
spatial_merge_size
self
.
spatial_merge_size
=
config
.
spatial_merge_size
...
@@ -526,14 +535,15 @@ class DotsVisionTransformer(PreTrainedModel):
...
@@ -526,14 +535,15 @@ class DotsVisionTransformer(PreTrainedModel):
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
out_hidden_size
=
config
.
hidden_size
# Keep blocks for compatibility with other vision towers
# Keep blocks for compatibility with other vision towers
num_layers
=
(
config
.
num_hidden_layers
if
num_hidden_layers_override
num_layers
=
(
config
.
num_hidden_layers
if
num_hidden_layers_override
is
None
else
num_hidden_layers_override
)
is
None
else
num_hidden_layers_override
)
self
.
blocks
=
nn
.
ModuleList
([
self
.
blocks
=
nn
.
ModuleList
([
DotsVisionBlock
(
config
,
DotsVisionBlock
(
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
i
}
"
)
prefix
=
f
"
{
prefix
}
.blocks.
{
i
}
"
,
use_data_parallel
=
use_data_parallel
)
for
i
in
range
(
num_layers
)
for
i
in
range
(
num_layers
)
])
])
if
require_post_norm
is
None
:
if
require_post_norm
is
None
:
...
@@ -548,6 +558,7 @@ class DotsVisionTransformer(PreTrainedModel):
...
@@ -548,6 +558,7 @@ class DotsVisionTransformer(PreTrainedModel):
dim
=
config
.
hidden_size
,
dim
=
config
.
hidden_size
,
context_dim
=
config
.
embed_dim
,
context_dim
=
config
.
embed_dim
,
spatial_merge_size
=
config
.
spatial_merge_size
,
spatial_merge_size
=
config
.
spatial_merge_size
,
use_data_parallel
=
use_data_parallel
,
)
)
@
property
@
property
...
@@ -604,7 +615,11 @@ class DotsVisionTransformer(PreTrainedModel):
...
@@ -604,7 +615,11 @@ class DotsVisionTransformer(PreTrainedModel):
return
max_seqlen
,
seqlens
return
max_seqlen
,
seqlens
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
grid_thw
:
torch
.
Tensor
)
->
torch
.
Tensor
:
grid_thw
:
list
[
list
[
int
]])
->
torch
.
Tensor
:
# Convert grid_thw to tensor (always expecting list format now)
grid_thw
=
torch
.
tensor
(
grid_thw
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
long
)
hidden_states
=
hidden_states
.
to
(
self
.
dtype
)
hidden_states
=
hidden_states
.
to
(
self
.
dtype
)
hidden_states
=
self
.
patch_embed
(
hidden_states
,
grid_thw
)
hidden_states
=
self
.
patch_embed
(
hidden_states
,
grid_thw
)
...
@@ -638,7 +653,8 @@ class DotsVisionTransformer(PreTrainedModel):
...
@@ -638,7 +653,8 @@ class DotsVisionTransformer(PreTrainedModel):
info
=
DotsOCRProcessingInfo
,
info
=
DotsOCRProcessingInfo
,
dummy_inputs
=
DotsOCRDummyInputsBuilder
,
dummy_inputs
=
DotsOCRDummyInputsBuilder
,
)
)
class
DotsOCRForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
class
DotsOCRForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsLoRA
):
hf_to_vllm_mapper
=
WeightsMapper
(
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
orig_to_new_substr
=
{
".attn.qkv_proj."
:
".attn.qkv."
,
".attn.qkv_proj."
:
".attn.qkv."
,
...
@@ -650,6 +666,21 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -650,6 +666,21 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
},
},
)
)
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
".attn.qkv"
:
[
".attn.qkv"
],
"fc13"
:
[
"fc1"
,
"fc3"
],
}
supports_encoder_tp_data
=
True
@
classmethod
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
if
modality
.
startswith
(
"image"
):
if
modality
.
startswith
(
"image"
):
...
@@ -660,19 +691,18 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -660,19 +691,18 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self
.
config
:
DotsOCRConfig
=
vllm_config
.
model_config
.
hf_config
self
.
config
:
DotsOCRConfig
=
vllm_config
.
model_config
.
hf_config
self
.
quant_config
=
vllm_config
.
quant_config
self
.
quant_config
=
vllm_config
.
quant_config
self
.
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
if
isinstance
(
self
.
config
.
vision_config
,
dict
):
if
isinstance
(
self
.
config
.
vision_config
,
dict
):
vision_config
=
DotsVisionConfig
(
**
self
.
config
.
vision_config
)
vision_config
=
DotsVisionConfig
(
**
self
.
config
.
vision_config
)
self
.
config
.
vision_config
=
vision_config
self
.
config
.
vision_config
=
vision_config
else
:
else
:
vision_config
=
self
.
config
.
vision_config
vision_config
=
self
.
config
.
vision_config
self
.
vision_tower
=
DotsVisionTransformer
(
self
.
vision_tower
=
DotsVisionTransformer
(
vision_config
,
vision_config
,
quant_config
=
self
.
quant_config
,
quant_config
=
self
.
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
use_data_parallel
=
self
.
use_data_parallel
)
self
.
language_model
:
Qwen2ForCausalLM
=
init_vllm_registered_model
(
self
.
language_model
:
Qwen2ForCausalLM
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
hf_config
=
self
.
config
,
hf_config
=
self
.
config
,
...
@@ -744,6 +774,15 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -744,6 +774,15 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
else
:
else
:
pixel_values
=
image_input
[
"pixel_values"
].
type
(
pixel_values
=
image_input
[
"pixel_values"
].
type
(
self
.
vision_tower
.
dtype
)
self
.
vision_tower
.
dtype
)
if
self
.
use_data_parallel
:
return
run_dp_sharded_mrope_vision_model
(
self
.
vision_tower
,
pixel_values
,
grid_thw_list
,
rope_type
=
"rope_3d"
,
)
else
:
image_embeds
=
self
.
vision_tower
(
image_embeds
=
self
.
vision_tower
(
pixel_values
,
grid_thw
)[:,
:
self
.
config
.
hidden_size
]
pixel_values
,
grid_thw
)[:,
:
self
.
config
.
hidden_size
]
...
@@ -822,3 +861,13 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -822,3 +861,13 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
Get the module prefix in multimodal models
"""
return
MultiModelKeys
.
from_string_field
(
language_model
=
"language_model"
,
connector
=
"vision_tower.merger"
,
tower_model
=
"vision_tower."
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment