wangsen / MinerU

Commit 8e55a526 authored May 27, 2025 by Jin Zhen Jiang

feat: add mineru-vlm backend.

Parent: 6f8a9610
Showing 11 changed files with 1954 additions and 0 deletions (+1954 -0)
mineru/model/vlm_hf_model/image_processing_mineru2.py   +269  -0
mineru/model/vlm_hf_model/modeling_mineru2.py           +445  -0
mineru/model/vlm_sglang_model/__init__.py                +21  -0
mineru/model/vlm_sglang_model/engine.py                 +264  -0
mineru/model/vlm_sglang_model/image_processor.py        +217  -0
mineru/model/vlm_sglang_model/logit_processor.py         +90  -0
mineru/model/vlm_sglang_model/model.py                  +448  -0
mineru/model/vlm_sglang_model/server.py                  +43  -0
mineru/utils/pdf_reader.py                               +98  -0
mineru/utils/run_async.py                                +52  -0
pyproject.toml                                            +7  -0
mineru/model/vlm_hf_model/image_processing_mineru2.py (new file, 0 → 100644)
import ast
import math
import re
from functools import partial, reduce
from typing import Dict, Optional, Union

import numpy as np
import torch
from PIL import Image
from transformers.image_processing_utils import (
    BaseImageProcessor,
    BatchFeature,
    get_size_dict,
)
from transformers.image_transforms import (
    convert_to_rgb,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    ChannelDimension,
    PILImageResampling,
    to_numpy_array,
)
from transformers.utils import TensorType


def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
    original_width, original_height = original_size
    best_fit = (0, 0)
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit
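
# --- Illustrative sketch (not part of the original commit) ---
# Worked example of the selection rule above: for a 1000x600 image, the
# (1536, 768) candidate keeps the most source pixels after an aspect-preserving
# downscale (1280x768 -> 600000 effective pixels), so it beats both square
# candidates. The numbers are hypothetical, chosen only for illustration.
def _demo_select_best_resolution():
    best = select_best_resolution((1000, 600), [(768, 768), (1536, 768), (768, 1536)])
    assert best == (1536, 768)
    return best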

def divide_to_patches(image, patch_size):
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)
    return patches

def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    if pil_img.mode == "L":
        pil_img = pil_img.convert("RGB")
    if width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        grid_pinpoints = [
            (i, j)
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]

    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)  # type: ignore
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
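
# --- Illustrative sketch (not part of the original commit) ---
# Worked example of the "(AxB),(CxD)" pinpoints string: "(1x1),(3x3)" expands
# to every grid from 1x1 through 3x3, scaled by patch_size. For a 1000x600
# image with patch_size=384, select_best_resolution picks (1152, 768), i.e. a
# 3x2 grid of patches. The values are hypothetical, for illustration only.
def _demo_anyres_grid_shape():
    grid = get_anyres_image_grid_shape((1000, 600), "(1x1),(3x3)", 384)
    assert grid == (3, 2)
    return grid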

# This function is not used.
def resize_and_pad_image(image, target_resolution):
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image

# DIFFERENT from sglang.srt.mm_utils.process_anyres_image
def process_anyres_image(image, processor, grid_pinpoints):
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        patch_size = processor.crop_size["height"]
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        grid_pinpoints = [
            (i, j)
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]

    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)  # type: ignore
    best_resolution = select_best_resolution(image.size, possible_resolutions)

    # image_padded = resize_and_pad_image(image, best_resolution)
    image_padded = image.resize(best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size["height"])
    image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))

    image_patches = [image_original_resize] + patches
    image_patches = [
        processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
        for image_patch in image_patches
    ]
    return torch.stack(image_patches, dim=0)
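
# --- Illustrative note (not part of the original commit) ---
# Example of the resulting shape: a 1000x600 image with pinpoints "(1x1),(3x3)"
# and a 384px crop size resizes to the best fit (1152, 768), which divides into
# 3x2 = 6 patches; with the prepended global resize, the stacked output is a
# (7, 3, 384, 384) tensor. Numbers are hypothetical, for illustration only.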

def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", "")
    new_images = []
    if image_aspect_ratio == "pad":
        for image in images:
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
        for image in images:
            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    else:
        return image_processor(images, return_tensors="pt")["pixel_values"]
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images

class Mineru2ImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(384, 384),
        crop_size: Optional[Dict[str, int]] = None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
        image_aspect_ratio: Optional[str] = None,
        image_grid_pinpoints: Optional[list] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size
        self.image_aspect_ratio = image_aspect_ratio
        self.image_grid_pinpoints = image_grid_pinpoints
        self.in_e2e_processing = False

    def _preprocess(self, images):
        if isinstance(images, Image.Image):
            images = [images]
        else:
            # to adapt video data
            images = [to_numpy_array(image) for image in images]
        assert isinstance(images, list)

        transforms = [
            convert_to_rgb,
            to_numpy_array,
            partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        ]
        images = reduce(lambda x, f: [*map(f, x)], transforms, images)
        return {"pixel_values": images}

    def _preprocess_end_to_end(self, images):
        image_aspect_ratio = self.image_aspect_ratio
        image_grid_pinpoints = self.image_grid_pinpoints
        assert image_aspect_ratio is not None
        assert image_grid_pinpoints is not None

        pixel_values = []
        if image_aspect_ratio == "pad":
            for image in images:
                image = expand2square(image, tuple(int(x * 255) for x in self.image_mean))
                image = self._preprocess(image)["pixel_values"][0]
                pixel_values.append(image)
        elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
            for image in images:
                image = process_anyres_image(image, self, self.image_grid_pinpoints)
                pixel_values.append(image.numpy())
        else:
            pixel_values = self._preprocess(images)["pixel_values"]

        if isinstance(pixel_values, list) and all(x.shape == pixel_values[0].shape for x in pixel_values):
            pixel_values = np.stack(pixel_values, axis=0)

        # CAUTION: here used (height, width).
        image_sizes = [(image.height, image.width) for image in images]
        assert len(pixel_values) == len(image_sizes)

        return {"pixel_values": pixel_values, "image_sizes": image_sizes}

    def preprocess(
        self,
        images,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ):
        if self.image_aspect_ratio is None or self.in_e2e_processing:
            data = self._preprocess(images)
        else:
            assert self.image_grid_pinpoints is not None
            self.in_e2e_processing = True
            try:
                data = self._preprocess_end_to_end(images)
            finally:
                self.in_e2e_processing = False

        return BatchFeature(data=data, tensor_type=return_tensors)
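
# --- Illustrative usage (not part of the original commit) ---
# A minimal sketch of the end-to-end path, assuming "anyres" mode and a
# hypothetical pinpoints range; the real values come from the model config.
def _demo_image_processor():
    from PIL import Image as _Image

    proc = Mineru2ImageProcessor(image_aspect_ratio="anyres", image_grid_pinpoints="(1x1),(3x3)")
    out = proc.preprocess([_Image.new("RGB", (1000, 600), "white")])
    # out["pixel_values"] stacks one (num_patches, 3, 384, 384) array per image;
    # out["image_sizes"] is [(600, 1000)] -- (height, width), as cautioned above.
    return out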
mineru/model/vlm_hf_model/modeling_mineru2.py (new file, 0 → 100644)
import math
import re
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    Qwen2ForCausalLM,
    Qwen2Model,
    SiglipVisionConfig,
    SiglipVisionModel,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_mineru2 import Mineru2QwenConfig
from .image_processing_mineru2 import Mineru2ImageProcessor, get_anyres_image_grid_shape


class SiglipVisionTower(nn.Module):
    def __init__(self, vision_tower):
        super().__init__()
        self.config = SiglipVisionConfig.from_pretrained(vision_tower)
        assert isinstance(self.config, SiglipVisionConfig)
        self.config.num_hidden_layers -= 1  # drop the last hidden layer
        self.config.vision_use_head = False
        self.vision_tower = SiglipVisionModel(self.config)
        self.vision_tower.requires_grad_(False)
        self.image_processor = Mineru2ImageProcessor()

    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        for p in self.vision_tower.parameters():
            return p.dtype

    @property
    def device(self):
        for p in self.vision_tower.parameters():
            return p.device

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def image_size(self):
        return self.config.image_size


def build_vision_tower(config: Mineru2QwenConfig):
    vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
    if "siglip" in vision_tower.lower():
        return SiglipVisionTower(vision_tower)
    raise ValueError(f"Unknown vision tower: {vision_tower}")


def build_vision_projector(config: Mineru2QwenConfig):
    projector_type = getattr(config, "mm_projector_type", "linear")

    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())  # type: ignore
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == "identity":
        return nn.Identity()

    raise ValueError(f"Unknown projector type: {projector_type}")
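
# --- Illustrative sketch (not part of the original commit) ---
# With mm_projector_type="mlp2x_gelu", the regex above yields depth 2, i.e.
# Linear(mm_hidden_size, hidden_size) -> GELU -> Linear(hidden_size, hidden_size).
# The config class below is a stand-in for Mineru2QwenConfig, with hypothetical
# sizes, used only to exercise build_vision_projector.
def _demo_projector():
    class _Cfg:
        mm_projector_type = "mlp2x_gelu"
        mm_hidden_size = 1152
        hidden_size = 2048

    proj = build_vision_projector(_Cfg())
    assert isinstance(proj, nn.Sequential) and len(proj) == 3
    return proj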

class Mineru2QwenModel(Qwen2Model):
    config_class = Mineru2QwenConfig

    def __init__(self, config: Mineru2QwenConfig):
        super(Mineru2QwenModel, self).__init__(config)
        self.vision_tower = build_vision_tower(config)
        self.mm_projector = build_vision_projector(config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))


class Mineru2QwenForCausalLM(Qwen2ForCausalLM):
    config_class = Mineru2QwenConfig

    def __init__(self, config: Mineru2QwenConfig):
        super(Qwen2ForCausalLM, self).__init__(config)
        config.rope_scaling = None
        self.model = Mineru2QwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.ignore_index = config.ignore_index
        self.image_token_index = config.image_token_index

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def encode_images(self, images: torch.Tensor):
        image_features = self.get_model().vision_tower(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
    ):
        vision_tower = self.get_model().vision_tower
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
            if mm_patch_merge_type == "flat":
                image_features = [x.flatten(0, 1) for x in image_features]
            elif mm_patch_merge_type.startswith("spatial"):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_model().vision_tower.num_patches_per_side
                        assert height * width == base_image_feature.shape[0]
                        if "anyres_max" in image_aspect_ratio:
                            matched_anyres_max_num_patches = re.match(r"square_anyres_max_(\d+)", image_aspect_ratio)
                            if matched_anyres_max_num_patches:
                                max_num_patches = int(matched_anyres_max_num_patches.group(1))
                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                                image_sizes[image_idx],
                                self.config.image_grid_pinpoints,
                                self.get_model().vision_tower.config.image_size,
                            )
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            raise NotImplementedError
                        if (
                            "unpad" in mm_patch_merge_type
                            and "anyres_max" in image_aspect_ratio
                            and matched_anyres_max_num_patches
                        ):
                            unit = image_feature.shape[2]
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            c, h, w = image_feature.shape
                            times = math.sqrt(h * w / (max_num_patches * unit**2))
                            if times > 1.1:
                                image_feature = image_feature[None]
                                image_feature = nn.functional.interpolate(
                                    image_feature, [int(h // times), int(w // times)], mode="bilinear"
                                )[0]
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.model.image_newline[:, None, None]
                                    .expand(*image_feature.shape[:-1], 1)
                                    .to(image_feature.device),
                                ),
                                dim=-1,
                            )
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.model.image_newline[:, None, None]
                                    .expand(*image_feature.shape[:-1], 1)
                                    .to(image_feature.device),
                                ),
                                dim=-1,
                            )
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    else:
                        image_feature = image_feature[0]
                        if "unpad" in mm_patch_merge_type:
                            image_feature = torch.cat(
                                (image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
                            )
                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, self.ignore_index)

        # remove the padding using attention_mask -- FIXME
        _input_ids = input_ids
        input_ids = [
            cur_input_ids[cur_attention_mask]
            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == self.image_token_index).sum()
            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = (
                [-1] + torch.where(cur_input_ids == self.image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],),
                            self.ignore_index,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len), self.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device
        )
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
                new_input_embeds_padded.append(
                    torch.cat(
                        (
                            torch.zeros(
                                (max_len - cur_len, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                            cur_new_embed,
                        ),
                        dim=0,
                    )
                )
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(
                        0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                    )
            else:
                new_input_embeds_padded.append(
                    torch.cat(
                        (
                            cur_new_embed,
                            torch.zeros(
                                (max_len - cur_len, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                        ),
                        dim=0,
                    )
                )
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(
                        0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                    )

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
                self.prepare_inputs_labels_for_multimodal(
                    input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
                )
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
            inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
        )

        return super().generate(
            position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs
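
# --- Illustrative usage (not part of the original commit) ---
# A minimal inference sketch, assuming the tokenizer maps an "<image>" special
# token to config.image_token_index and that `process_images` (from
# image_processing_mineru2 above) prepared the pixel tensors; the prompt text
# and max_new_tokens value are placeholders.
def _demo_generate(model, tokenizer, pixel_values, pil_image):
    input_ids = tokenizer("<image>\nExtract the text.", return_tensors="pt").input_ids
    image_sizes = [(pil_image.height, pil_image.width)]  # (height, width) order
    return model.generate(
        input_ids,
        images=pixel_values.to(model.device, model.dtype),
        image_sizes=image_sizes,
        max_new_tokens=512,
    )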
mineru/model/vlm_sglang_model/__init__.py (new file, 0 → 100644)
from sglang.srt.configs.model_config import multimodal_model_archs
from sglang.srt.models.registry import ModelRegistry

try:
    # sglang==0.4.5.post3
    from sglang.srt.managers.multimodal_processor import (
        PROCESSOR_MAPPING as PROCESSOR_MAPPING,
    )
except ImportError:
    # sglang==0.4.4.post1
    from sglang.srt.managers.image_processor import (
        IMAGE_PROCESSOR_MAPPING as PROCESSOR_MAPPING,
    )

from .. import vlm_hf_model as _
from .image_processor import Mineru2ImageProcessor
from .model import Mineru2QwenForCausalLM

ModelRegistry.models[Mineru2QwenForCausalLM.__name__] = Mineru2QwenForCausalLM
PROCESSOR_MAPPING[Mineru2QwenForCausalLM] = Mineru2ImageProcessor
multimodal_model_archs.append(Mineru2QwenForCausalLM.__name__)
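
# --- Illustrative note (not part of the original commit) ---
# Importing this package is the whole integration step: the three statements
# above register the architecture, its image processor, and its multimodal
# status with sglang, e.g.
#
#   import mineru.model.vlm_sglang_model  # side effect: registers Mineru2QwenForCausalLM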
mineru/model/vlm_sglang_model/engine.py (new file, 0 → 100644)
import asyncio
import time
from types import MethodType
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union

import fastapi
from sglang.srt.entrypoints.engine import Engine as _Engine
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.managers.tokenizer_manager import (
    TokenizerManager,
    dataclass_to_string_truncated,
    logger,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

from ...utils.run_async import run_async
from .logit_processor import Mineru2LogitProcessor


class BatchEngine(_Engine):
    """
    This engine is patched to support batch multi-modal generation and early image preprocessing.
    """

    def __init__(self, server_args: ServerArgs, **kwargs):
        server_args.enable_custom_logit_processor = True
        super().__init__(server_args=server_args, **kwargs)
        _patch_tokenizer_manager(self.tokenizer_manager)

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        # The image input. It can be a file name, a url, or base64 encoded string.
        # See also python/sglang/srt/utils.py:load_image.
        image_data: Optional[Union[List[str], str]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
        return_hidden_states: bool = False,
        stream: bool = False,
    ) -> Union[Dict, Iterator[Dict]]:
        """
        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
        Please refer to `GenerateReqInput` for the documentation.
        """
        modalities_list = []

        # EDIT
        if isinstance(image_data, list):
            for _ in range(len(image_data)):
                modalities_list.append(["image"])
        elif image_data is not None:
            modalities_list.append("image")

        # ADD
        if custom_logit_processor is None:
            custom_logit_processor = Mineru2LogitProcessor().to_str()

        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            image_data=image_data,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            token_ids_logprob=token_ids_logprob,
            lora_path=lora_path,
            modalities=modalities_list,
            custom_logit_processor=custom_logit_processor,
            return_hidden_states=return_hidden_states,
            stream=stream,
        )
        generator = _generate_request(self.tokenizer_manager, obj, None)

        if stream:

            def generator_wrapper():
                while True:
                    try:
                        chunk = run_async(generator.__anext__())
                        yield chunk
                    except StopAsyncIteration:
                        break

            return generator_wrapper()
        else:
            ret = run_async(generator.__anext__())
            return ret

    async def async_generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        # The image input. It can be a file name, a url, or base64 encoded string.
        # See also python/sglang/srt/utils.py:load_image.
        image_data: Optional[Union[List[str], str]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
        return_hidden_states: bool = False,
        stream: bool = False,
    ) -> Union[Dict, AsyncIterator[Dict], Iterator[Dict]]:
        """
        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
        Please refer to `GenerateReqInput` for the documentation.
        """
        modalities_list = []

        # EDIT
        if isinstance(image_data, list):
            for _ in range(len(image_data)):
                modalities_list.append(["image"])
        elif image_data is not None:
            modalities_list.append("image")

        # ADD
        if custom_logit_processor is None:
            custom_logit_processor = Mineru2LogitProcessor().to_str()

        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            image_data=image_data,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            token_ids_logprob=token_ids_logprob,
            lora_path=lora_path,
            modalities=modalities_list,
            custom_logit_processor=custom_logit_processor,
            return_hidden_states=return_hidden_states,
            stream=stream,
        )
        generator = _generate_request(self.tokenizer_manager, obj, None)

        if stream is True:
            return generator
        else:
            return await generator.__anext__()

def _auto_create_handle_loop(self: TokenizerManager):
    """
    Patch the original `auto_create_handle_loop()` method to reset `no_create_loop`
    when the event loop changes.
    """
    try:
        curr_handle_loop = asyncio.get_running_loop()
    except RuntimeError:
        curr_handle_loop = None

    last_handle_loop = getattr(self, "_last_handle_loop", None)
    if last_handle_loop != curr_handle_loop:
        self.no_create_loop = False
        setattr(self, "_last_handle_loop", curr_handle_loop)

    return TokenizerManager.auto_create_handle_loop(self)


def _patch_tokenizer_manager(self: TokenizerManager):
    self.auto_create_handle_loop = MethodType(_auto_create_handle_loop, self)


async def _one_request(
    self: TokenizerManager,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    request: Optional[fastapi.Request],
    created_time: Optional[float],
):
    tokenized_obj = await self._tokenize_one_request(obj)
    self._send_one_request(obj, tokenized_obj, created_time)
    async for out in self._wait_one_response(obj, request):
        yield out


async def _handle_batch_request(
    self: TokenizerManager,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    request: Optional[fastapi.Request] = None,
    created_time: Optional[float] = None,
):
    batch_size = obj.batch_size

    generators = []
    rids = []

    if getattr(obj, "parallel_sample_num", 1) != 1:
        raise Exception("parallel_sample_num != 1 is not supported in this patched code.")

    # Send all requests
    for i in range(batch_size):
        tmp_obj = obj[i]
        generators.append(_one_request(self, tmp_obj, request, created_time))
        rids.append(tmp_obj.rid)

    # Wait for all requests
    is_stream = hasattr(obj, "stream") and obj.stream
    if not is_stream:
        outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
        yield outputs
    else:
        rid_to_index = {rid: i for i, rid in enumerate(rids)}
        task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
        while task_map:
            done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                gen = task_map.pop(task)
                try:
                    result = task.result()
                    result["index"] = rid_to_index[result["meta_info"]["id"]]
                    yield result
                    new_task = asyncio.create_task(gen.__anext__())
                    task_map[new_task] = gen
                except StopAsyncIteration:
                    pass


async def _generate_request(
    self: TokenizerManager,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    request: Optional[fastapi.Request] = None,
):
    created_time = time.time()

    self.auto_create_handle_loop()

    if isinstance(obj, EmbeddingReqInput) and self.is_generation:
        raise ValueError(
            "This model does not appear to be an embedding model by default. "
            "Please add `--is-embedding` when launching the server or try another model."
        )

    obj.normalize_batch_and_arguments()

    if self.log_requests:
        max_length, skip_names, _ = self.log_request_metadata
        logger.info(f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}")

    async with self.model_update_lock.reader_lock:
        is_single = obj.is_single
        if is_single:
            tokenized_obj = await self._tokenize_one_request(obj)
            self._send_one_request(obj, tokenized_obj, created_time)
            async for response in self._wait_one_response(obj, request):
                yield response
        else:
            async for response in _handle_batch_request(self, obj, request, created_time):
                yield response
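
# --- Illustrative usage (not part of the original commit) ---
# A minimal sketch of batch generation through the patched engine; the model
# path, prompt format, and file names below are placeholders.
def _demo_batch_generate():
    engine = BatchEngine(ServerArgs(model_path="/path/to/mineru2-checkpoint"))
    # With stream=False, `generate` returns once every request in the batch finishes.
    return engine.generate(
        prompt=["<image>\nExtract the text.", "<image>\nExtract the text."],
        image_data=["page_0.png", "page_1.png"],  # file names, URLs, or base64 strings
        sampling_params=[{"temperature": 0.0}] * 2,
    )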
mineru/model/vlm_sglang_model/image_processor.py (new file, 0 → 100644)
import ast
import asyncio
import re
from typing import List, Optional, Union

import numpy as np

try:
    # sglang==0.4.5.post3
    from sglang.srt.managers.multimodal_processors.base_processor import (
        BaseMultimodalProcessor as BaseProcessor,
    )

    get_global_processor = None
except ImportError:
    # sglang==0.4.4.post1
    from sglang.srt.managers.image_processors.base_image_processor import (
        BaseImageProcessor as BaseProcessor,
        get_global_processor,
    )

from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution
from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback

from .model import Mineru2QwenForCausalLM


# image_best_res is only resized (not padded).
def process_anyres_image(image, processor, grid_pinpoints):
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        patch_size = processor.crop_size["height"]
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        grid_pinpoints = [
            (i, j)
            for i in range(range_start[0], range_end[0] + 1)
            for j in range(range_start[1], range_end[1] + 1)
        ]
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]

    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_best_res = image.resize(best_resolution)  # <<<<<<< Here changed

    patches = divide_to_patches(image_best_res, processor.crop_size["height"])
    image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))

    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch)["pixel_values"][0] for image_patch in image_patches]
    return np.stack(image_patches, axis=0)

class Mineru2ImageProcessor(BaseProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_single_image_task(
        image_data: Union[str, bytes],
        image_aspect_ratio: Optional[str] = None,
        image_grid_pinpoints: Optional[str] = None,
        image_processor=None,
    ):
        if image_processor is None:
            assert get_global_processor is not None
            image_processor = get_global_processor().image_processor

        try:
            image, image_size = load_image(image_data)
            if image_size is not None:
                # It is a video with multiple images
                image_hash = hash(image_data)
                pixel_values = image_processor(image)["pixel_values"]
                pixel_values = np.stack(pixel_values, axis=0)
                return pixel_values, image_hash, image_size
            else:
                # It is an image
                image_hash = hash(image_data)
                if image_aspect_ratio == "pad":
                    image = expand2square(
                        image,
                        tuple(int(x * 255) for x in image_processor.image_mean),
                    )
                    pixel_values = image_processor(image.convert("RGB"))["pixel_values"][0]
                elif image_aspect_ratio == "anyres" or (
                    image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
                ):
                    pixel_values = process_anyres_image(image, image_processor, image_grid_pinpoints)
                else:
                    pixel_values = image_processor(image)["pixel_values"][0]
                return pixel_values, image_hash, image.size
        except Exception:
            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())

    async def _process_single_image(self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str):
        if hasattr(self, "cpu_executor"):
            executor = self.cpu_executor
        else:
            executor = self.executor

        if get_global_processor is not None:
            image_processor = None  # save ipc cost
        else:
            image_processor = self._processor.image_processor

        if executor is not None:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                executor,
                Mineru2ImageProcessor._process_single_image_task,
                image_data,
                aspect_ratio,
                grid_pinpoints,
                image_processor,
            )
        else:
            return self._process_single_image_task(
                image_data,
                aspect_ratio,
                grid_pinpoints,
                image_processor,
            )
    # sglang==0.4.4.post1
    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        if not image_data:
            return None

        modalities = request_obj.modalities or ["image"]
        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", "")
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints
            if hasattr(self.hf_config, "image_grid_pinpoints") and "anyres" in aspect_ratio
            else None
        )

        if isinstance(image_data, str):
            image_data = [image_data]

        if isinstance(image_data, list) and len(image_data) > 0:
            if "multi-images" in modalities or "video" in modalities:
                # Multiple images
                # LLaVA OneVision handling: more than one image --> interleaved image mode or video mode. We do not use anyres.
                aspect_ratio = "pad"
                pixel_values, image_hashes, image_sizes = [], [], []
                res = []
                for img_data in image_data:
                    res.append(self._process_single_image(img_data, aspect_ratio, grid_pinpoints))
                res = await asyncio.gather(*res)
                for pixel_v, image_h, image_s in res:
                    pixel_values.append(pixel_v)
                    image_hashes.append(image_h)
                    image_sizes.append(image_s)
                if isinstance(pixel_values[0], np.ndarray):
                    pixel_values = np.stack(pixel_values, axis=0)
            else:
                # A single image
                pixel_values, image_hash, image_size = await self._process_single_image(
                    image_data[0], aspect_ratio, grid_pinpoints
                )
                image_hashes = [image_hash]
                image_sizes = [image_size]
        else:
            raise ValueError(f"Invalid image data: {image_data}")

        return {
            "pixel_values": pixel_values,
            "image_hashes": image_hashes,
            "image_sizes": image_sizes,
            "modalities": request_obj.modalities or ["image"],
        }

    # sglang==0.4.5.post3
    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

        result = await self.process_images_async(image_data, input_text, request_obj, *args, **kwargs)
        if result is None:
            return None

        modality = Modality.IMAGE
        if isinstance(request_obj.modalities, list):
            if request_obj.modalities[0] == "multi-images":
                modality = Modality.MULTI_IMAGES
            elif request_obj.modalities[0] == "video":
                modality = Modality.VIDEO

        return {
            "mm_items": [
                MultimodalDataItem(
                    pixel_values=result["pixel_values"],
                    image_sizes=result["image_sizes"],
                    modality=modality,
                )
            ],
        }


ImageProcessorMapping = {Mineru2QwenForCausalLM: Mineru2ImageProcessor}
mineru/model/vlm_sglang_model/logit_processor.py (new file, 0 → 100644)
from typing import List

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class Mineru2LogitProcessor(CustomLogitProcessor):
    """
    Stateless logit processor for Mineru2.
    (base-class: sglang.srt.sampling.custom_logit_processor.CustomLogitProcessor)

    This processor applies token-level constraints to prevent repetition during generation.

    It supports two main constraints:

    - no_repeat_ngram_size (int):
        Prevents repeating the same n-gram of the specified size in the output.
        Inspired by Hugging Face's NoRepeatNGramLogitsProcessor.
        This implementation is slower due to its lack of specialized optimization.

    - no_repeat_token_count (int):
        (Placeholder for future logic)
        Intended to prevent repeating the same token multiple times.
        Not yet implemented in this version.
    """

    def __init__(self) -> None:
        super().__init__()
        self._generated_ngrams = {}  # Cache of generated n-grams by request ID
        self._time = {}  # Timestamp of the last update for each request
        self._gen_step = 0  # Global generation step counter

    def __call__(self, logits, batch_info: List[dict]):
        """
        Applies repetition constraints to the logits before sampling tokens.

        Args:
            logits (FloatTensor): A tensor of shape (batch_size, vocab_size) containing raw token logits.
            batch_info (List[dict]): A list of metadata dicts for each sample in the batch. Each dict must include:
                - "__req__": Request object containing request ID and output_ids.
                - "no_repeat_ngram_size": Size of n-gram to avoid repeating.

        Returns:
            FloatTensor: The modified logits tensor with banned token logits set to -inf.
        """
        from sglang.srt.managers.schedule_batch import Req

        self._gen_step += 1  # Update global generation step

        for idx, info in enumerate(batch_info):
            if not isinstance(info, dict) or "__req__" not in info:
                continue

            req: Req = info["__req__"]
            rid = req.rid
            output_ids = req.output_ids
            ngram_size = info.get("no_repeat_ngram_size", 0)

            # Skip if there are not enough tokens to form an n-gram
            if ngram_size <= 0 or len(output_ids) < ngram_size:
                continue

            # Record the current step for cache cleanup tracking
            self._time[rid] = self._gen_step

            # Initialize the n-gram cache for this request if it doesn't exist
            if rid not in self._generated_ngrams:
                self._generated_ngrams[rid] = {}

            # Get the n-gram prefix (all but the last token)
            prev_ngram = tuple(output_ids[-ngram_size:-1])
            last_token = output_ids[-1]

            # Store this n-gram occurrence
            self._generated_ngrams[rid][prev_ngram] = self._generated_ngrams[rid].get(prev_ngram, []) + [last_token]

            # Get the next-token candidates to ban based on the current prefix
            current_prefix = tuple(output_ids[-ngram_size + 1 :])
            banned_tokens = self._generated_ngrams[rid].get(current_prefix, [])

            # Set the logits of banned tokens to negative infinity
            for token in banned_tokens:
                logits[idx][token] = -float("inf")

        # Clean up the cache for expired requests
        expired_rids = [rid for rid, last_used in self._time.items() if last_used < self._gen_step]
        for rid in expired_rids:
            self._generated_ngrams.pop(rid, None)
            self._time.pop(rid, None)

        return logits
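
# --- Illustrative sketch (not part of the original commit) ---
# The ban rule in isolation: with ngram_size=3, after emitting [5, 6, 7, 5, 6]
# the prefix (5, 6) was previously followed by 7, so 7 is banned at the next
# step. This mirrors the prefix bookkeeping in __call__ without any sglang
# machinery; the token ids are hypothetical.
def _demo_banned_tokens(output_ids, ngram_size):
    seen = {}
    for end in range(ngram_size, len(output_ids) + 1):
        prefix = tuple(output_ids[end - ngram_size : end - 1])
        seen.setdefault(prefix, []).append(output_ids[end - 1])
    current_prefix = tuple(output_ids[-(ngram_size - 1) :])
    return seen.get(current_prefix, [])


# _demo_banned_tokens([5, 6, 7, 5, 6], 3) == [7]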
mineru/model/vlm_sglang_model/model.py (new file, 0 → 100644)
import math
import re
from typing import Iterable, List, Optional, Tuple

import numpy as np
import torch
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    # unpad_image, unpad_image_shape
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix
from torch import nn
from transformers import (
    CLIPVisionConfig,
    CLIPVisionModel,
    SiglipVisionConfig,
    SiglipVisionModel,
)

from ..vlm_hf_model.configuration_mineru2 import Mineru2QwenConfig
from ..vlm_hf_model.modeling_mineru2 import build_vision_projector


def flatten_nested_list(nested_list):
    if isinstance(nested_list, list):
        return [item for sublist in nested_list for item in flatten_nested_list(sublist)]
    else:
        return [nested_list]


def downgrade_modality(modality):
    modality_str = str(modality)
    if "MULTI_IMAGES" in modality_str:
        return "multi-images"
    if "IMAGE" in modality_str:
        return "image"
    if "VIDEO" in modality_str:
        return "video"
    if "AUDIO" in modality_str:
        return "audio"
    raise ValueError(f"Unexpected modality: {modality_str}")

class Mineru2QwenForCausalLM(nn.Module):
    def __init__(
        self,
        config: Mineru2QwenConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"
        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 151646

        # load vision tower
        mm_vision_tower = self.config.mm_vision_tower
        if "clip" in mm_vision_tower:
            vision_config = CLIPVisionConfig.from_pretrained(mm_vision_tower)
            self.vision_tower = CLIPVisionModel(vision_config)  # type: ignore
        elif "siglip" in mm_vision_tower:
            vision_config = SiglipVisionConfig.from_pretrained(mm_vision_tower)
            self.vision_tower = SiglipVisionModel(vision_config)  # type: ignore
            # Siglip needs all feature tokens
            self.config.mm_vision_select_feature = "full"
        else:
            raise ValueError(f"Unexpected mm_vision_tower: {mm_vision_tower}")

        ### EDIT: change projector
        # The name `projector` contains `proj`, which is often used in attention layers and can cause bugs in quantization.
        self.multi_modal_mlp = build_vision_projector(config)

        self.language_model = Qwen2ForCausalLM(
            config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(torch.empty(config.hidden_size))

        language_model_device = next(self.language_model.parameters()).device
        self.vision_tower = self.vision_tower.to(language_model_device)
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
        if self.vision_feature_select_strategy in ("patch", "full"):
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(f"Unexpected select feature: {self.vision_feature_select_strategy}")

    def pad_input_ids(self, input_ids: List[int], image_inputs):
        if hasattr(image_inputs, "mm_items"):
            # MultimodalInputs (sglang==0.4.5.post3)
            image_sizes = flatten_nested_list([item.image_sizes for item in image_inputs.mm_items])
            pad_values = [item.pad_value for item in image_inputs.mm_items]
        else:
            # ImageInputs (sglang==0.4.4.post1)
            image_sizes = image_inputs.image_sizes
            pad_values = image_inputs.pad_values

        # hardcode for spatial_unpad + anyres
        # if image_inputs.modalities is not None and (
        #     "multi-images" in image_inputs.modalities or "video" in image_inputs.modalities
        # ):
        #     image_aspect_ratio = "pad"
        # else:
        #     image_aspect_ratio = "anyres"

        offset_list = []
        image_inputs.image_pad_len = []
        for image_idx, image_s in enumerate(image_sizes):
            if len(image_sizes) > 16:
                # 2x2 pooling with stride 2
                new_image_feature_len = math.ceil(self.image_size / self.patch_size / 2) ** 2
            else:
                new_image_feature_len = self.image_feature_len  # multi-image

            height = width = self.num_patches_per_side
            if "anyres" in self.config.image_aspect_ratio:
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_s,
                    self.image_grid_pinpoints,
                    self.vision_tower.config.image_size,
                )
                h = num_patch_height * height
                w = num_patch_width * width

                ### EDIT: remove `unpad_image_shape`
                # new_h, new_w = unpad_image_shape(h, w, image_s)
                new_h, new_w = h, w

                if "anyres_max" in self.config.image_aspect_ratio:
                    matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", self.config.image_aspect_ratio)
                    if matched_anyres_max_num_patches:
                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
                    times = math.sqrt(new_h * new_w / (max_num_patches * self.image_feature_len))
                    if times > 1.1:
                        new_h = int(new_h // times)
                        new_w = int(new_w // times)
                new_image_feature_len += new_h * (new_w + 1)

            try:
                offset = input_ids.index(self.config.image_token_index)
            except ValueError:
                offset = 0
            # old_len + pad_len - 1, because we need to remove image_token_id
            input_ids = input_ids[:offset] + [pad_values[image_idx]] * new_image_feature_len + input_ids[offset + 1 :]
            offset_list.append(offset)
            image_inputs.image_pad_len.append(new_image_feature_len)

        image_inputs.image_offsets = offset_list
        return input_ids
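
    # --- Illustrative note (not part of the original commit) ---
    # Effect of pad_input_ids on a token stream (hypothetical numbers): each
    # image token is replaced in place by `new_image_feature_len` copies of that
    # image's pad value, so radix-cache prefix matching sees a stable sequence,
    # e.g. [101, 151646, 102] -> [101, 77, 77, 77, 77, 77, 102]
    # for pad_value=77 and new_image_feature_len=5.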

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype)
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
        image_features = self.multi_modal_mlp(selected_image_feature)
        return image_features
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        if hasattr(forward_batch, "mm_inputs"):
            # sglang==0.4.5.post3
            image_inputs = forward_batch.mm_inputs
            is_sglang_mm_inputs = True
        else:
            # sglang==0.4.4.post1
            image_inputs = forward_batch.image_inputs
            is_sglang_mm_inputs = False
        if image_inputs is None:
            image_inputs = []

        if forward_batch.forward_mode.is_extend():
            # Clamp input ids. This is because the input_ids for the image tokens are
            # filled with the hash values of the image for the prefix matching in the radix attention.
            # Their values are useless because their embeddings will be replaced by vision embeddings anyway.
            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

            # Embed text inputs
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Got List[List[str]]; flatten it to List[str].
            # The length of the list should equal the batch size.
            modalities_list = []
            max_image_offset = []
            for im in image_inputs:
                if im:
                    if hasattr(im, "mm_items"):
                        # sglang==0.4.5.post3
                        modalities_list.extend([downgrade_modality(item.modality) for item in im.mm_items])
                    elif im.modalities is not None:
                        # sglang==0.4.4.post1
                        modalities_list.extend(im.modalities)
                if im and im.image_offsets:
                    max_image_offset.append(np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)))
                else:
                    max_image_offset.append(-1)

            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
            need_vision = start_positions <= np.array(max_image_offset)

            if need_vision.any():
                bs = forward_batch.batch_size
                if is_sglang_mm_inputs:
                    # sglang==0.4.5.post3
                    # image_inputs[batch_idx].mm_items[item_idx].pixel_values is a Tensor
                    pixel_values = flatten_nested_list(
                        [[item.pixel_values for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
                    )
                    # image_inputs[batch_idx].mm_items[item_idx].image_sizes should be a tuple, but is a list of tuples for now.
                    image_sizes = [
                        flatten_nested_list([item.image_sizes for item in image_inputs[i].mm_items])
                        for i in range(bs)
                        if need_vision[i]
                    ]
                else:
                    # sglang==0.4.4.post1
                    pixel_values = [image_inputs[i].pixel_values for i in range(bs) if need_vision[i]]
                    image_sizes = [image_inputs[i].image_sizes for i in range(bs) if need_vision[i]]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336; num_patch obtained from process_images
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(np.array(pixel_values), device=self.vision_tower.device)
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                if self.mm_patch_merge_type.startswith("spatial"):
                    new_image_features = []
                    height = width = self.num_patches_per_side
                    for image_idx, image_feature in enumerate(image_features):
                        if modalities_list[image_idx] == "image":
                            image_aspect_ratio = self.config.image_aspect_ratio  # single image
                        elif modalities_list[image_idx] == "multi-images" or modalities_list[image_idx] == "video":
                            image_aspect_ratio = "pad"  # multi image
                            # image_aspect_ratio = (
                            #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
                            # )
                        if (
                            image_feature.shape[0] > 1
                            and "anyres" in image_aspect_ratio
                            and modalities_list[image_idx] == "image"
                        ):
                            base_image_feature = image_feature[0]
                            image_feature = image_feature[1:]
                            assert height * width == base_image_feature.shape[0]

                            if "anyres_max" in image_aspect_ratio:
                                matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", image_aspect_ratio)
                                if matched_anyres_max_num_patches:
                                    max_num_patches = int(matched_anyres_max_num_patches.group(1))

                            if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                                vision_tower_image_size = self.image_size
                                try:
                                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                                        image_sizes[image_idx][0],
                                        self.config.image_grid_pinpoints,
                                        vision_tower_image_size,
                                    )
                                except Exception as e:
                                    print(f"Error: {e}")
                                    num_patch_width, num_patch_height = 2, 2
                                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                            else:
                                image_feature = image_feature.view(2, 2, height, width, -1)

                            if "unpad" in self.mm_patch_merge_type:
                                unit = image_feature.shape[2]
                                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                                ### EDIT: remove `unpad_image`
                                # image_feature = unpad_image(image_feature, image_sizes[image_idx][0])
                                if "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
                                    c, h, w = image_feature.shape
                                    times = math.sqrt(h * w / (max_num_patches * unit**2))
                                    if times > 1.1:
                                        image_feature = image_feature[None]
                                        image_feature = nn.functional.interpolate(
                                            image_feature,
                                            [int(h // times), int(w // times)],
                                            mode="bilinear",
                                        )[0]
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[:, None, None].expand(
                                            *image_feature.shape[:-1], 1
                                        ),
                                    ),
                                    dim=-1,
                                )
                                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                            else:
                                image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                                image_feature = image_feature.flatten(0, 3)
                            image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                            image_feature = image_feature.unsqueeze(0)
                        else:
                            if modalities_list[image_idx] == "video":  # video
                                # 2x2 pooling
                                num_of_frames = image_feature.shape[0]
                                image_feature = image_feature.view(num_of_frames, height, width, -1)
                                image_feature = image_feature.permute(0, 3, 1, 2).contiguous()  # N, C, H, W
                                # Use local names so the per-side patch counts above are not clobbered.
                                frame_height, frame_width = image_feature.shape[2:]
                                scaled_shape = [
                                    math.ceil(frame_height / 2),
                                    math.ceil(frame_width / 2),
                                ]
                                image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode="bilinear")
                                image_feature = image_feature.flatten(2).transpose(1, 2).contiguous()  # N, H*W, C
                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
                                        self.language_model.model.image_newline[None, None].expand(
                                            image_feature.shape[0],
                                            1,
                                            image_feature.shape[-1],
                                        ),
                                    ),
                                    dim=1,
                                )
                        new_image_features.append(image_feature)
                    image_features = new_image_features

                # Fill in the placeholders for the images
                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
                extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
                prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    seq_len = extend_seq_lens[i]
                    prefix_len = prefix_lens_cpu[i]

                    # Multiple images
                    for image_idx, image_offset in enumerate(image_inputs[i].image_offsets):
                        if image_offset + image_inputs[i].image_pad_len[image_idx] <= prefix_len:
                            continue
                        if image_offset >= prefix_len + seq_len:
                            break

                        tmp_image_feature = image_features[pt][image_idx]
                        pad_len = tmp_image_feature.shape[0]

                        input_offset = image_offset - prefix_len
                        left_idx = start_idx + input_offset
                        right_idx = left_idx + pad_len
                        assert right_idx > start_idx
                        if input_offset < 0:
                            left_idx = start_idx
                            tmp_image_feature = tmp_image_feature[-input_offset:]
                        if right_idx > start_idx + seq_len:
                            tmp_image_feature = tmp_image_feature[: start_idx + seq_len - right_idx]
                            right_idx = start_idx + seq_len
                        try:
                            input_embeds[left_idx:right_idx] = tmp_image_feature
                        except RuntimeError as e:
                            print(f"RuntimeError in image encoding: {e}")
                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
                            print(f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}")
                    pt += 1

            return self.language_model(input_ids, positions, forward_batch, input_embeds=input_embeds)
        elif forward_batch.forward_mode.is_decode():
            return self.language_model(input_ids, positions, forward_batch)
        else:
            raise ValueError(f"Unexpected forward mode: {forward_batch.forward_mode}")
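    # Note: only extend (prefill) batches splice vision features into
    # input_embeds; decode steps reuse the KV cache, so image embeddings are
    # never recomputed there.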
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        projector_weights = {
            "model.mm_projector": "multi_modal_mlp",
            "model.vision_tower.vision_tower": "vision_tower",
            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
            "model.image_newline": "language_model.model.image_newline",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                self.language_model.load_weights([(name, loaded_weight)])
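    # Example: a checkpoint key such as "model.mm_projector.0.weight" (the
    # exact submodule indices depend on the checkpoint) is remapped to
    # "multi_modal_mlp.0.weight" by the table above before the params_dict
    # lookup.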
    @property
    def num_patches_per_side(self):
        return self.image_size // self.patch_size


EntryClass = [Mineru2QwenForCausalLM]
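For reference, a minimal sketch of the patch arithmetic the class relies on; the 336/14 values are typical CLIP ViT-L/14 defaults, assumed here rather than read from a config:

image_size, patch_size = 336, 14  # assumed vision-tower defaults
num_patches_per_side = image_size // patch_size  # 24
image_feature_len = num_patches_per_side**2  # 576 image tokens per tile
print(num_patches_per_side, image_feature_len)  # 24 576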
mineru/model/vlm_sglang_model/server.py
0 → 100644
View file @ 8e55a526
import os
import sys

from fastapi import Request
from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree

from .logit_processor import Mineru2LogitProcessor

_custom_logit_processor_str = Mineru2LogitProcessor().to_str()


# remove the existing /generate route
for route in app.routes[:]:
    if hasattr(route, "path") and getattr(route, "path") == "/generate":
        app.routes.remove(route)


# add the custom /generate route
@app.api_route("/generate", methods=["POST", "PUT"])
async def custom_generate_request(obj: GenerateReqInput, request: Request):
    if obj.custom_logit_processor is None:
        obj.custom_logit_processor = _custom_logit_processor_str
    return await generate_request(obj, request)


def main():
    server_args = prepare_server_args(sys.argv[1:])

    if server_args.chat_template is None:
        server_args.chat_template = "chatml"
    server_args.enable_custom_logit_processor = True

    try:
        launch_server(server_args)
    finally:
        kill_process_tree(os.getpid(), include_parent=False)


if __name__ == "__main__":
    main()
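A minimal client sketch for the patched route. The host and port are hypothetical, and the payload follows sglang's GenerateReqInput JSON; the custom logit processor is attached server-side, so the client does not send one:

import requests

# Hypothetical address; use whatever --host/--port the server was launched with.
resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "Extract the text from this page.",
        "sampling_params": {"temperature": 0.0, "max_new_tokens": 1024},
    },
)
print(resp.json())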
mineru/utils/pdf_reader.py
0 → 100644
View file @ 8e55a526
# Copyright (c) Opendatalab. All rights reserved.
import base64
from io import BytesIO

from loguru import logger
from PIL import Image
from pypdfium2 import PdfBitmap, PdfDocument, PdfPage


def page_to_image(
    page: PdfPage,
    dpi: int = 144,  # changed from 200 to 144
    max_width_or_height: int = 2560,  # changed from 4500 to 2560
) -> tuple[Image.Image, float]:
    scale = dpi / 72

    long_side_length = max(*page.get_size())
    if long_side_length > max_width_or_height:
        scale = max_width_or_height / long_side_length

    bitmap: PdfBitmap = page.render(scale=scale)  # type: ignore
    try:
        image = bitmap.to_pil()
    finally:
        try:
            bitmap.close()
        except Exception:
            pass
    return image, scale
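# Behavior sketch: dpi=144 gives scale = 144/72 = 2.0. The cap compares the
# page's longer side in points: a 3000pt page gets scale 2560/3000 ≈ 0.85
# (2560px output), while a 2000pt page keeps scale 2.0 (4000px output).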
def image_to_bytes(
    image: Image.Image,
    image_format: str = "PNG",  # "JPEG" also works
) -> bytes:
    with BytesIO() as image_buffer:
        image.save(image_buffer, format=image_format)
        return image_buffer.getvalue()


def image_to_b64str(
    image: Image.Image,
    image_format: str = "PNG",  # "JPEG" also works
) -> str:
    image_bytes = image_to_bytes(image, image_format)
    return base64.b64encode(image_bytes).decode("utf-8")
def pdf_to_images(
    pdf: str | bytes | PdfDocument,
    dpi: int = 144,
    max_width_or_height: int = 2560,
    start_page_id: int = 0,
    end_page_id: int | None = None,
) -> list[Image.Image]:
    doc = pdf if isinstance(pdf, PdfDocument) else PdfDocument(pdf)

    page_num = len(doc)
    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else page_num - 1
    if end_page_id > page_num - 1:
        logger.warning("end_page_id is out of range, falling back to the last page")
        end_page_id = page_num - 1

    images = []
    try:
        for i in range(start_page_id, end_page_id + 1):
            image, _ = page_to_image(doc[i], dpi, max_width_or_height)
            images.append(image)
    finally:
        try:
            doc.close()
        except Exception:
            pass
    return images
def pdf_to_images_bytes(
    pdf: str | bytes | PdfDocument,
    dpi: int = 144,
    max_width_or_height: int = 2560,
    start_page_id: int = 0,
    end_page_id: int | None = None,
    image_format: str = "PNG",
) -> list[bytes]:
    images = pdf_to_images(pdf, dpi, max_width_or_height, start_page_id, end_page_id)
    return [image_to_bytes(image, image_format) for image in images]


def pdf_to_images_b64strs(
    pdf: str | bytes | PdfDocument,
    dpi: int = 144,
    max_width_or_height: int = 2560,
    start_page_id: int = 0,
    end_page_id: int | None = None,
    image_format: str = "PNG",
) -> list[str]:
    images = pdf_to_images(pdf, dpi, max_width_or_height, start_page_id, end_page_id)
    return [image_to_b64str(image, image_format) for image in images]
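A quick usage sketch (the file name is hypothetical):

from mineru.utils.pdf_reader import pdf_to_images_b64strs

# Render the first two pages of a local PDF to base64-encoded PNG strings.
pages = pdf_to_images_b64strs("example.pdf", dpi=144, end_page_id=1)
print(len(pages), pages[0][:32])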
mineru/utils/run_async.py
0 → 100644
View file @ 8e55a526
import asyncio
import threading
from queue import Queue
from typing import Any, AsyncIterable, Coroutine, Iterable, TypeVar

T = TypeVar("T")


def run_async(coroutine: Coroutine[Any, Any, T]) -> T:
    if not asyncio.iscoroutine(coroutine):
        raise ValueError("a coroutine was expected, got {!r}".format(coroutine))
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so we can own one directly.
        return asyncio.run(coroutine)
    # A loop is already running here; blocking on it with run_until_complete()
    # would raise, so run the coroutine on a fresh loop in a worker thread.
    result_queue: Queue = Queue(maxsize=1)

    def _worker():
        try:
            result_queue.put((True, asyncio.run(coroutine)))
        except BaseException as e:  # propagate any failure to the caller
            result_queue.put((False, e))

    thread = threading.Thread(target=_worker, daemon=True)
    thread.start()
    ok, value = result_queue.get()
    thread.join()
    if ok:
        return value
    raise value


def iter_async(iterable: AsyncIterable[T]) -> Iterable[T]:
    if not isinstance(iterable, AsyncIterable):
        raise ValueError("an async iterable was expected, got {!r}".format(iterable))

    # None is used as the end-of-stream sentinel; exceptions are forwarded.
    queue = Queue()

    async def async_helper():
        try:
            async for chunk in iterable:
                queue.put(chunk)
            queue.put(None)
        except Exception as e:
            queue.put(e)

    def helper():
        run_async(async_helper())

    thread = threading.Thread(target=helper, daemon=True)
    thread.start()

    while True:
        chunk = queue.get()
        if chunk is None:
            break
        if isinstance(chunk, Exception):
            raise chunk
        yield chunk

    thread.join()
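A small usage sketch of both helpers:

import asyncio

from mineru.utils.run_async import iter_async, run_async

async def double(x: int) -> int:
    await asyncio.sleep(0)
    return 2 * x

async def count(n: int):
    for i in range(n):
        yield i

print(run_async(double(21)))  # 42
print(list(iter_async(count(3))))  # [0, 1, 2]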
pyproject.toml
0 → 100644
View file @ 8e55a526
[tool.black]
line-length = 128

[tool.ruff]
line-length = 128