Commit ff2ce0b8 (unverified)
authored Mar 12, 2025 by Mick, committed by GitHub Mar 11, 2025

refactor: move image processors to separate files (#4229)

parent 0f2a2e3c

Showing 20 changed files with 1069 additions and 945 deletions (+1069 / -945)
benchmark/mmmu/bench_hf.py                                             +18   -21
benchmark/mmmu/bench_sglang.py                                         +49   -33
benchmark/mmmu/data_utils.py                                            +1    -0
benchmark/mmmu/eval_utils.py                                           +22    -3
python/sglang/srt/conversation.py                                       +1    -2
python/sglang/srt/hf_transformers_utils.py                              +1    -0
python/sglang/srt/layers/attention/vision.py                           +42   -61
python/sglang/srt/managers/image_processor.py                          +37  -631
python/sglang/srt/managers/image_processors/base_image_processor.py   +206    -0
python/sglang/srt/managers/image_processors/llava.py                  +152    -0
python/sglang/srt/managers/image_processors/minicpmv.py                +86    -0
python/sglang/srt/managers/image_processors/mlama.py                   +60    -0
python/sglang/srt/managers/image_processors/qwen_vl.py                +161    -0
python/sglang/srt/managers/multi_modality_padding.py                  +134    -0
python/sglang/srt/managers/schedule_batch.py                           +11    -5
python/sglang/srt/model_loader/weight_utils.py                          +1    -1
python/sglang/srt/models/minicpmv.py                                   +28   -89
python/sglang/srt/models/mllama.py                                      +1    -1
python/sglang/srt/models/qwen2_5_vl.py                                 +25   -49
python/sglang/srt/models/qwen2_vl.py                                   +33   -49
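
Each of the new files under python/sglang/srt/managers/image_processors/ ends with an ImageProcessorMapping dict that maps a model class to its processor class. A minimal sketch of how these per-file mappings could be merged and used to pick a processor for a loaded model; the get_image_processor helper below is hypothetical and not part of this commit, the real dispatch lives in the (collapsed) image_processor.py diff:

# Hypothetical dispatch sketch, not code from this commit.
from sglang.srt.managers.image_processors import llava, minicpmv, mlama, qwen_vl

def get_image_processor(model_cls, hf_config, server_args, processor):
    # Merge the per-file mappings and look up the processor for this model class.
    mapping = {}
    for module in (llava, minicpmv, mlama, qwen_vl):
        mapping.update(module.ImageProcessorMapping)
    processor_cls = mapping[model_cls]  # KeyError for models without a processor
    return processor_cls(hf_config, server_args, processor)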
benchmark/mmmu/bench_hf.py

@@ -11,11 +11,16 @@ import argparse
 import random

 import torch
-from bench_sglang import EvalArgs, prepare_samples
 from data_utils import save_json
-from eval_utils import eval_result, get_sampling_params, parse_multi_choice_response
+from eval_utils import (
+    EvalArgs,
+    eval_result,
+    get_sampling_params,
+    prepare_samples,
+    process_result,
+)
 from tqdm import tqdm
-from transformers import AutoModelForImageTextToText, AutoProcessor
+from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig


 @torch.no_grad()
@@ -28,7 +33,6 @@ def eval_mmmu(args):
         trust_remote_code=True,
     )
     model = model.eval().cuda()
-    model = torch.compile(model)
     processor = AutoProcessor.from_pretrained(
         args.model_path, torch_dtype="auto", device_map="auto"
@@ -38,6 +42,10 @@ def eval_mmmu(args):
     out_samples = dict()
     sampling_params = get_sampling_params(eval_args)
+    generation_config = GenerationConfig(
+        max_new_tokens=sampling_params["max_new_tokens"],
+        do_sample=False,
+    )
     answer_dict = {}
     for sample in tqdm(samples):
@@ -62,7 +70,6 @@ def eval_mmmu(args):
             text = processor.apply_chat_template(
                 messages, tokenize=False, add_generation_prompt=True
             )
             inputs = processor(
                 text=[text],
                 images=[image],
@@ -70,13 +77,16 @@ def eval_mmmu(args):
                 return_tensors="pt",
             ).to(model.device)
-            generated_ids = model.generate(**inputs, **sampling_params)
+            generated_ids = model.generate(
+                **inputs, generation_config=generation_config
+            )
             response = processor.decode(
                 generated_ids[0],
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
             )[len(text):]
+            print(f"response: {response}")
         else:  # multiple images actually
             if sample["question_type"] == "multiple-choice":
                 all_choices = sample["all_choices"]
@@ -85,24 +95,11 @@ def eval_mmmu(args):
             else:
                 response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"

-        if sample["question_type"] == "multiple-choice":
-            pred_ans = parse_multi_choice_response(
-                response, sample["all_choices"], sample["index2ans"]
-            )
-        else:  # open question
-            pred_ans = response
-        out_samples[sample["id"]] = pred_ans
-        torch.cuda.empty_cache()
-        # set ground truth answer
-        answer_dict[sample["id"]] = {
-            "question_type": sample["question_type"],
-            "ground_truth": sample["answer"],
-        }
+        process_result(response, sample, answer_dict, out_samples)

     args.output_path = f"{args.model_path}_val_hf.json"
     save_json(args.output_path, out_samples)
-    eval_result(output_path=args.output_path, answer_dict=answer_dict)
+    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)


 if __name__ == "__main__":
benchmark/mmmu/bench_sglang.py

@@ -8,9 +8,9 @@
 """

 import argparse
+import base64
 import dataclasses
 import random
-import re
 from io import BytesIO

 from data_utils import save_json
@@ -18,13 +18,14 @@ from eval_utils import (
     EvalArgs,
     eval_result,
     get_sampling_params,
-    parse_multi_choice_response,
     prepare_samples,
+    process_result,
 )
 from tqdm import tqdm

 from sglang import Engine
-from sglang.srt.conversation import chat_templates
+from sglang.srt.conversation import generate_chat_conv
+from sglang.srt.openai_api.protocol import ChatCompletionRequest
 from sglang.srt.server_args import ServerArgs
@@ -35,61 +36,76 @@ def eval_mmmu(args):
     if server_args.chat_template is None:
         raise ValueError("Chat template must be provided for this benchmark")

+    samples = prepare_samples(eval_args)
+
     backend = Engine(**dataclasses.asdict(server_args))

     out_samples = dict()

     sampling_params = get_sampling_params(eval_args)

-    conv = chat_templates[server_args.chat_template].copy()
-    samples = prepare_samples(eval_args)
-    image_token = conv.image_token
     answer_dict = {}
     for sample in tqdm(samples):
         prompt = sample["final_input_prompt"]
         image = sample["image"]
-        bytes_io = BytesIO()
-        image.save(bytes_io, format="PNG")
-        png_bytes = bytes_io.getvalue()
+        buff = BytesIO()
+        image.save(buff, format="PNG")
+        base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")

-        prompt = re.sub(r"<[^>]*>", image_token, prompt)
+        prefix = prompt.split("<")[0]
+        suffix = prompt.split(">")[1]
+        request_dict = {
+            "model": "",
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": prefix,
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{base64_str}"
+                            },
+                        },
+                        {
+                            "type": "text",
+                            "text": suffix,
+                        },
+                    ],
+                }
+            ],
+        }
+
+        conv = generate_chat_conv(
+            ChatCompletionRequest(**request_dict),
+            template_name=server_args.chat_template,
+        )
+        prompt = conv.get_prompt()

         if image is not None:
             gen_out = backend.generate(
-                prompt=prompt, image_data=[png_bytes], sampling_params=sampling_params
+                prompt=prompt,
+                image_data=conv.image_data,
+                sampling_params=sampling_params,
             )["text"]
             response = gen_out
         else:  # multiple images actually
             if sample["question_type"] == "multiple-choice":
                 all_choices = sample["all_choices"]
                 response = random.choice(all_choices)
             else:
                 response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"

-        if sample["question_type"] == "multiple-choice":
-            pred_ans = parse_multi_choice_response(
-                response, sample["all_choices"], sample["index2ans"]
-            )
-        else:  # open question
-            pred_ans = response
-        out_samples[sample["id"]] = pred_ans
-        # set ground truth answer
-        answer_dict[sample["id"]] = {
-            "question_type": sample["question_type"],
-            "ground_truth": (
-                sample["correct_choice"] if "correct_choice" in samples else sample["answer"]
-            ),
-        }
+        process_result(response, sample, answer_dict, out_samples)

     args.output_path = f"{args.model_path}_val_sglang.json"
     save_json(args.output_path, out_samples)
-    eval_result(output_path=args.output_path, answer_dict=answer_dict)
+    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
+    backend.shutdown()


 if __name__ == "__main__":
benchmark/mmmu/data_utils.py

@@ -143,6 +143,7 @@ def process_single_sample(data):
 # DATA SAVING
 def save_json(filename, ds):
+    print(f"answers saved to: {filename}")
     os.makedirs(os.path.dirname(filename), exist_ok=True)
     with open(filename, "w") as f:
         json.dump(ds, f, indent=4)
benchmark/mmmu/eval_utils.py

@@ -87,6 +87,7 @@ def set_seed(seed_value):
 def prepare_samples(eval_args: EvalArgs):
+    print("preparing samples...")
     # Build prompts
     set_seed(eval_args.seed)
@@ -110,6 +111,7 @@ def prepare_samples(eval_args: EvalArgs):
             eval_args.dataset_path, subject, split=eval_args.split
         )
         sub_dataset_list.append(sub_dataset)
+        # break

     # merge all dataset
     dataset = concatenate_datasets(sub_dataset_list)
@@ -426,9 +428,26 @@ def calculate_ins_level_acc(results: Dict):
     return acc / ins_num


-def eval_result(output_path, answer_dict):
+def process_result(response, sample, answer_dict, out_samples):
+    if sample["question_type"] == "multiple-choice":
+        pred_ans = parse_multi_choice_response(
+            response, sample["all_choices"], sample["index2ans"]
+        )
+    else:  # open question
+        pred_ans = response
+    out_samples[sample["id"]] = pred_ans
+    # set ground truth answer
+    answer_dict[sample["id"]] = {
+        "question_type": sample["question_type"],
+        "ground_truth": sample["answer"],
+    }
+
+
+def eval_result(model_answer_path, answer_dict):
     print("Evaluating...")
-    output_dict = json.load(open(output_path))
+    output_dict = json.load(open(model_answer_path))
     # answer_dict = json.load(open(answer_path))

     # group by category
@@ -521,7 +540,7 @@ def eval_result(output_path, answer_dict):
         "acc": overall_acc,
     }
     pprint.pprint(printable_results)
-    out = output_path
+    out = model_answer_path
     with open(out, "w", encoding="utf-8") as outfile:
         json.dump(printable_results, outfile)
     print(f"eval out saved to {out}")
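
The new process_result helper consolidates the per-sample bookkeeping that both benchmark scripts previously duplicated. A minimal sketch of the call pattern with made-up sample data; the values, and the assumption that parse_multi_choice_response maps the reply back to the letter "B", are illustrative only:

answer_dict, out_samples = {}, {}
sample = {                                 # hypothetical MMMU-style sample
    "id": "validation_Art_1",
    "question_type": "multiple-choice",
    "all_choices": ["A", "B", "C", "D"],
    "index2ans": {"A": "cat", "B": "dog", "C": "bird", "D": "fish"},
    "answer": "B",
}
process_result("The answer is B", sample, answer_dict, out_samples)
# Expected: out_samples["validation_Art_1"] holds the parsed prediction and
# answer_dict["validation_Art_1"] records question_type plus ground_truth "B".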
python/sglang/srt/conversation.py

@@ -191,7 +191,7 @@ class Conversation:
         for i, (role, message) in enumerate(self.messages):
             if i % 2 == 0:
                 ret += f"[Round {i//2 + round_add_n}]{self.sep}"
             if message:
                 ret += f"{role}:{message}{self.sep}"
@@ -453,7 +453,6 @@ def generate_chat_conv(
                 conv.system_message = getattr(message.content[0], "text", "")
         elif msg_role == "user":
             # Handle the various types of Chat Request content types here.
-            role = conv.roles[0]
             if isinstance(message.content, str):
                 conv.append_message(conv.roles[0], message.content)
             else:
python/sglang/srt/hf_transformers_utils.py

@@ -66,6 +66,7 @@ def get_config(
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
python/sglang/srt/layers/attention/vision.py

 from __future__ import annotations

 from functools import lru_cache
-from typing import Optional
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
@@ -22,47 +22,29 @@ from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.utils import add_prefix


-def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
-    if not interleaved:
-        x1, x2 = x.chunk(2, dim=-1)
-        return torch.cat((-x2, x1), dim=-1)
-    else:
-        x1, x2 = x[..., ::2], x[..., 1::2]
-        return rearrange(
-            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
-        )
+# Copied from transformers, modeling_qwen2_vl.py
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_emb_torch(
-    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
-) -> torch.Tensor:
-    """
-    x: (batch_size, seqlen, nheads, headdim)
-    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
-    """
-    ro_dim = cos.shape[-1] * 2
-    assert ro_dim <= x.shape[-1]
-    cos = repeat(
-        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
-    )
-    sin = repeat(
-        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
-    )
-    return torch.cat(
-        [
-            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
-            x[..., ro_dim:],
-        ],
-        dim=-1,
-    )
-
-
-def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-    t_ = t.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
-    return output
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed


 class VisionAttention(nn.Module):
@@ -75,8 +57,8 @@ class VisionAttention(nn.Module):
         use_context_forward (bool, default to True):
             if ``True``, a flash_attn style attention will be applied
             Otherwise, a full-sequence attention will be applied.
-        use_full_precision_softmax (bool, default to False):
-            if ``True``, the softmax will be performed in full-precision
+        softmax_in_single_precision (bool, default to False):
+            if ``True``, the softmax will be performed in single-precision
             Otherwise, it will be performed in half-precision
     """
@@ -90,7 +72,7 @@ class VisionAttention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         dropout: float = 0.0,
         use_context_forward: bool = True,
-        use_full_precision_softmax: bool = False,
+        softmax_in_single_precision: bool = False,
         flatten_batch: bool = False,
         prefix: str = "",
     ):
@@ -113,7 +95,7 @@ class VisionAttention(nn.Module):
                 head_size=self.head_size,
                 dropout=dropout,
                 flatten_batch=flatten_batch,
-                use_full_precision_softmax=use_full_precision_softmax,
+                softmax_in_single_precision=softmax_in_single_precision,
             )

         self.use_qkv_parallel = use_qkv_parallel
@@ -143,7 +125,7 @@ class VisionAttention(nn.Module):
         self,
         x: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor] = None,
-        rotary_pos_emb: torch.Tensor = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         r"""
@@ -151,21 +133,17 @@ class VisionAttention(nn.Module):
             x: [b, s, embed_dim]
             cu_seqlens: [b]
         Returns:
-             [s, b, num_heads * head]
+             [s, b, head * head_size]
         """
         bsz, s, _ = x.shape
+        head = self.num_attention_heads_per_partition
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
             q, k, v = qkv.chunk(3, dim=-1)

-            # [b, s, embed_dim] --> [b * s, num_heads, head_size]
-            q, k, v = [
-                x.reshape(bsz * s, self.num_attention_heads_per_partition, -1).contiguous()
-                for x in (q, k, v)
-            ]
+            # [b, s, embed_dim] --> [b * s, head, head_size]
+            q, k, v = [
+                x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)
+            ]
         else:
             # [b, s, embed_dim] --> [s, b, embed_dim]
             x = rearrange(x, "b s ... -> s b ...")
@@ -173,7 +151,7 @@ class VisionAttention(nn.Module):
             qkv, _ = self.qkv_proj(x)
             # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
             new_x_shape = qkv.size()[:-1] + (
-                self.num_attention_heads_per_partition,
+                head,
                 3 * self.hidden_size_per_attention_head,
             )
             qkv = qkv.view(*new_x_shape)
@@ -186,9 +164,12 @@ class VisionAttention(nn.Module):
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
             ]

-        if rotary_pos_emb is not None:
-            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            original_shape = q.shape
+            q, k = q.view(s, head, -1), k.view(s, head, -1)
+            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+            q, k = q.reshape(original_shape), k.reshape(original_shape)

         if self.use_qkv_parallel:
             pass
@@ -230,12 +211,12 @@ class VisionSdpaAttention(nn.Module):
         head_size: int,
         dropout: float = 0.0,
         flatten_batch: bool = False,
-        use_full_precision_softmax: bool = False,
+        softmax_in_single_precision: bool = False,
     ):
         super().__init__()
         self.head_size = head_size
         self.flatten_batch = flatten_batch
-        self.use_full_precision_softmax = use_full_precision_softmax
+        self.softmax_in_single_precision = softmax_in_single_precision
         self.dropout = dropout

     @staticmethod
@@ -319,14 +300,14 @@ class VisionSdpaAttention(nn.Module):
             )
         if attention_mask is None:
-            if self.use_full_precision_softmax:
+            if self.softmax_in_single_precision:
                 raise RuntimeError("Empty attention mask")
         else:
             attention_mask = attention_mask.to(device=q.device)

         q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]

-        if self.use_full_precision_softmax:
+        if self.softmax_in_single_precision:
             scale = self.head_size**-0.5
             k_transposed = rearrange(k, "b h s d -> b h d s")
             attn_weights = torch.matmul(q, k_transposed) * scale
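
The rotary helper now takes q and k together with precomputed cos/sin tables instead of a single freqs tensor. A small shape check of the new signature; the dimensions are arbitrary, and the cos/sin layout of [s, head_dim] is an assumption based on the unsqueeze(-2) broadcast in the function above:

import torch
from sglang.srt.layers.attention.vision import apply_rotary_pos_emb_vision

s, head, head_dim = 16, 4, 64                    # arbitrary sizes for the check
q = torch.randn(s, head, head_dim)
k = torch.randn(s, head, head_dim)
cos = torch.randn(s, head_dim)                   # position_embeddings = (cos, sin)
sin = torch.randn(s, head_dim)
q_embed, k_embed = apply_rotary_pos_emb_vision(q, k, cos, sin)
assert q_embed.shape == q.shape and k_embed.shape == k.shape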
python/sglang/srt/managers/image_processor.py

(diff collapsed in the original page view; not shown)
python/sglang/srt/managers/image_processors/base_image_processor.py (new file, mode 100644)

import concurrent
import concurrent.futures
import dataclasses
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import Optional

import PIL
import transformers
from decord import VideoReader, cpu
from PIL import Image

from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import load_image

global global_processor


def get_global_processor():
    global global_processor
    return global_processor


@dataclasses.dataclass
class BaseImageProcessorOutput:
    image_hashes: list[int]
    image_sizes: list[tuple[int, int]]
    all_frames: [PIL.Image]
    # input_text, with each frame of video/image represented as an image_token
    input_text: str


class BaseImageProcessor(ABC):
    def __init__(self, hf_config, server_args, _processor):
        self.hf_config = hf_config
        self._processor = _processor
        self.server_args = server_args
        # FIXME: not accurate, model and image specific
        self.NUM_TOKEN_PER_FRAME = 330

        self.executor = concurrent.futures.ProcessPoolExecutor(
            initializer=init_global_processor,
            mp_context=mp.get_context("fork"),
            initargs=(
                self,
                server_args,
            ),
            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
        )

    def _build_processor(self, server_args):
        """Init the global processor for multi modal models."""
        from sglang.srt.hf_transformers_utils import get_processor

        return get_processor(
            server_args.tokenizer_path,
            tokenizer_mode=server_args.tokenizer_mode,
            trust_remote_code=server_args.trust_remote_code,
        )

    @abstractmethod
    async def process_images_async(
        self, image_data, input_text, max_req_input_len, **kwargs
    ):
        pass

    def get_estimated_frames_list(self, image_data):
        """
        estimate the total frame count from all visual input
        """
        # Before processing inputs
        estimated_frames_list = []
        for image in image_data:
            if isinstance(image, str) and image.startswith("video:"):
                path = image[len("video:") :]
                # Estimate frames for the video
                vr = VideoReader(path, ctx=cpu(0))
                num_frames = len(vr)
            else:
                # For images, each contributes one frame
                num_frames = 1
            estimated_frames_list.append(num_frames)

        return estimated_frames_list

    @staticmethod
    def encode_video(video_path, frame_count_limit=None):
        if not os.path.exists(video_path):
            logger.error(f"Video {video_path} does not exist")
            return []

        if frame_count_limit == 0:
            return []

        def uniform_sample(l, n):
            gap = len(l) / n
            idxs = [int(i * gap + gap / 2) for i in range(n)]
            return [l[i] for i in idxs]

        vr = VideoReader(video_path, ctx=cpu(0))
        sample_fps = round(vr.get_avg_fps() / 1)  # FPS
        frame_indices = [i for i in range(0, len(vr), sample_fps)]
        if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
            frame_indices = uniform_sample(frame_indices, frame_count_limit)

        frames = vr.get_batch(frame_indices).asnumpy()
        frames = [Image.fromarray(v.astype("uint8")) for v in frames]
        return frames

    def load_images(
        self,
        input_ids: list,
        image_data,
        image_token: str,
        max_req_input_len: int,
        return_text: Optional[bool] = True,
        discard_alpha_channel: bool = True,
    ) -> BaseImageProcessorOutput:
        """
        Each frame of video/image will be replaced by a single image token
        """
        image_hashes, image_sizes = [], []
        all_frames = []
        new_text_parts = []

        if isinstance(input_ids, list) and return_text:
            assert len(input_ids) and isinstance(input_ids[0], int)
            input_text = self._processor.tokenizer.decode(input_ids)
        else:
            input_text = input_ids
        if return_text:
            text_parts = input_text.split(image_token)

        # roughly calculate the max number of frames under the max_req_input_len limit
        MAX_NUM_FRAMES = 30
        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
        total_frame_count = sum(estimated_frames_list)
        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)

        assert len(image_data) == len(estimated_frames_list)

        # Process each input with allocated frames
        for image_index, (image, estimated_frames) in enumerate(
            zip(image_data, estimated_frames_list)
        ):
            if len(all_frames) >= MAX_NUM_FRAMES:
                max_frames_to_process = 0
            else:
                max_frames_to_process = max(1, int(estimated_frames * scaling_factor))

            if max_frames_to_process == 0:
                frames = []
            else:
                try:
                    if isinstance(image, str) and image.startswith("video:"):
                        path = image[len("video:") :]
                        frames = BaseImageProcessor.encode_video(
                            path, frame_count_limit=max_frames_to_process
                        )
                    else:
                        raw_image, _size = load_image(image)
                        if discard_alpha_channel:
                            raw_image = raw_image.convert("RGB")
                        frames = [raw_image]
                    assert len(frames) != 0
                except FileNotFoundError as e:
                    print(e)
                    return None

            image_sizes += [frames[0].size] * len(frames)
            image_hashes += [hash(image)] * len(frames)
            all_frames += frames

            if return_text:
                new_text_parts.append(text_parts[image_index])
                if max_frames_to_process != 0:
                    new_text_parts.append(image_token * len(frames))
            assert max_frames_to_process >= len(frames)
        if return_text:
            new_text_parts.append(text_parts[-1])
        input_text = "".join(new_text_parts)
        return BaseImageProcessorOutput(
            image_hashes, image_sizes, all_frames, input_text
        )


class DummyImageProcessor(BaseImageProcessor):
    def __init__(self):
        pass

    async def process_images_async(self, *args, **kwargs):
        return None


def init_global_processor(
    sglang_image_processor: BaseImageProcessor, server_args: ServerArgs
):
    """Init the global processor for multi-modal models."""
    global global_processor
    transformers.logging.set_verbosity_error()
    global_processor = sglang_image_processor._build_processor(server_args=server_args)
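
Model-specific processors subclass BaseImageProcessor and implement process_images_async, typically delegating frame extraction and token expansion to load_images. A minimal skeleton of such a subclass; the class name, token string, and returned dict keys below are placeholders rather than code from this commit:

class MyModelImageProcessor(BaseImageProcessor):          # hypothetical subclass
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.IMAGE_TOKEN = "<image>"                      # placeholder token

    async def process_images_async(
        self, image_data, input_ids, request_obj, max_req_input_len, **kwargs
    ):
        if not image_data:
            return None
        base_output = self.load_images(
            input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
        )
        if base_output is None or len(base_output.all_frames) == 0:
            return None
        # Run the HF processor on the decoded text plus the extracted frames.
        result = self._processor(
            text=base_output.input_text,
            images=base_output.all_frames,
            return_tensors="pt",
        )
        return {
            "input_ids": result.input_ids.flatten().tolist(),
            "pixel_values": result.pixel_values,
            "image_hashes": base_output.image_hashes,
            "modalities": request_obj.modalities or ["image"],
        }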
python/sglang/srt/managers/image_processors/llava.py (new file, mode 100644)

import asyncio
from typing import List, Optional, Union

import numpy as np

from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
    get_global_processor,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
from sglang.srt.models.llavavid import LlavaVidForCausalLM
from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback


class LlavaImageProcessor(BaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_single_image_task(
        image_data: Union[str, bytes],
        image_aspect_ratio: Optional[str] = None,
        image_grid_pinpoints: Optional[str] = None,
        image_processor=None,
    ):
        processor = get_global_processor()
        image_processor = image_processor or processor.image_processor

        try:
            image, image_size = load_image(image_data)
            if image_size is not None:
                # It is a video with multiple images
                image_hash = hash(image_data)
                pixel_values = image_processor(image)["pixel_values"]
                for _ in range(len(pixel_values)):
                    pixel_values[_] = pixel_values[_].astype(np.float16)
                pixel_values = np.stack(pixel_values, axis=0)
                return pixel_values, image_hash, image_size
            else:
                # It is an image
                image_hash = hash(image_data)
                if image_aspect_ratio == "pad":
                    image = expand2square(
                        image,
                        tuple(int(x * 255) for x in image_processor.image_mean),
                    )
                    pixel_values = image_processor(image.convert("RGB"))[
                        "pixel_values"
                    ][0]
                elif image_aspect_ratio == "anyres" or (
                    image_aspect_ratio is not None
                    and "anyres_max" in image_aspect_ratio
                ):
                    pixel_values = process_anyres_image(
                        image, image_processor, image_grid_pinpoints
                    )
                else:
                    pixel_values = image_processor(image)["pixel_values"][0]

                if isinstance(pixel_values, np.ndarray):
                    pixel_values = pixel_values.astype(np.float16)

                return pixel_values, image_hash, image.size
        except Exception:
            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())

    async def _process_single_image(
        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
    ):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor,
                LlavaImageProcessor._process_single_image_task,
                image_data,
                aspect_ratio,
                grid_pinpoints,
            )
        else:
            return self._process_single_image_task(
                image_data, aspect_ratio, grid_pinpoints
            )

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        if not image_data:
            return None

        modalities = request_obj.modalities or ["image"]
        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints
            if hasattr(self.hf_config, "image_grid_pinpoints")
            and "anyres" in aspect_ratio
            else None
        )

        if isinstance(image_data, str):
            image_data = [image_data]

        if isinstance(image_data, list) and len(image_data) > 0:
            if "multi-images" in modalities or "video" in modalities:
                # Multiple images
                # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
                aspect_ratio = "pad"
                pixel_values, image_hashes, image_sizes = [], [], []
                res = []
                for img_data in image_data:
                    res.append(
                        self._process_single_image(
                            img_data, aspect_ratio, grid_pinpoints
                        )
                    )
                res = await asyncio.gather(*res)
                for pixel_v, image_h, image_s in res:
                    pixel_values.append(pixel_v)
                    image_hashes.append(image_h)
                    image_sizes.append(image_s)

                if isinstance(pixel_values[0], np.ndarray):
                    pixel_values = np.stack(pixel_values, axis=0)
            else:
                # A single image
                pixel_values, image_hash, image_size = await self._process_single_image(
                    image_data[0], aspect_ratio, grid_pinpoints
                )
                image_hashes = [image_hash]
                image_sizes = [image_size]
        else:
            raise ValueError(f"Invalid image data: {image_data}")

        return {
            "pixel_values": pixel_values,
            "image_hashes": image_hashes,
            "image_sizes": image_sizes,
            "modalities": request_obj.modalities or ["image"],
        }


ImageProcessorMapping = {
    LlavaVidForCausalLM: LlavaImageProcessor,
    LlavaQwenForCausalLM: LlavaImageProcessor,
    LlavaMistralForCausalLM: LlavaImageProcessor,
}
python/sglang/srt/managers/image_processors/minicpmv.py (new file, mode 100644)

import asyncio
from typing import List, Union

from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
    get_global_processor,
)
from sglang.srt.models.minicpmv import MiniCPMV


class MiniCPMVImageProcessor(BaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.IMAGE_TOKEN = "(<image>./</image>)"

    @staticmethod
    def _process_images_task(images, input_text):
        processor = get_global_processor()
        result = processor.__call__(text=input_text, images=images, return_tensors="pt")
        return {
            "input_ids": result.input_ids,
            "pixel_values": result.pixel_values,
            "tgt_sizes": result.tgt_sizes,
        }

    async def _process_images(self, images, input_text):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            image_inputs = await loop.run_in_executor(
                self.executor,
                MiniCPMVImageProcessor._process_images_task,
                images,
                input_text,
            )
        else:
            image_inputs = self._processor(
                images=images, text=input_text, return_tensors="pt"
            )

        return image_inputs

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_ids,
        request_obj,
        max_req_input_len,
    ):
        if not image_data:
            return None
        if not isinstance(image_data, list):
            image_data = [image_data]

        base_output = self.load_images(
            input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
        )
        if base_output is None:
            return None

        if len(base_output.all_frames) == 0:
            return None
        res = await self._process_images(
            images=base_output.all_frames, input_text=base_output.input_text
        )

        # Collect special token ids
        tokenizer = self._processor.tokenizer
        im_start_id = tokenizer.im_start_id
        im_end_id = tokenizer.im_end_id
        if tokenizer.slice_start_id:
            slice_start_id = tokenizer.slice_start_id
            slice_end_id = tokenizer.slice_end_id

        return {
            "input_ids": res["input_ids"].flatten().tolist(),
            "pixel_values": res["pixel_values"],
            "tgt_sizes": res["tgt_sizes"],
            "image_hashes": base_output.image_hashes,
            "modalities": request_obj.modalities or ["image"],
            "im_start_id": im_start_id,
            "im_end_id": im_end_id,
            "slice_start_id": slice_start_id,
            "slice_end_id": slice_end_id,
        }


ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor}
python/sglang/srt/managers/image_processors/mlama.py (new file, mode 100644)

import asyncio
from typing import List, Union

from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
    get_global_processor,
)
from sglang.srt.models.mllama import MllamaForConditionalGeneration
from sglang.srt.utils import load_image


class MllamaImageProcessor(BaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_single_image_task(images, input_text):
        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
        return get_global_processor()(images, input_text, return_tensors="pt")

    async def _process_single_image(self, images, input_text):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            image_inputs = await loop.run_in_executor(
                self.executor,
                MllamaImageProcessor._process_single_image_task,
                images,
                input_text,
            )
        else:
            image_inputs = self._processor(images, input_text, return_tensors="pt")

        return image_inputs

    async def process_images_async(
        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
    ):
        if not image_data:
            return None

        if isinstance(input_text, list):
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

        if not isinstance(image_data, list):
            image_data = [image_data]

        if len(image_data) > 0:
            images = [load_image(image)[0] for image in image_data]
        else:
            images = load_image(image_data[0])[0]

        image_inputs = await self._process_single_image(images, input_text)
        image_inputs["image_hashes"] = [hash(str(image_data))]
        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]

        return image_inputs


ImageProcessorMapping = {MllamaForConditionalGeneration: MllamaImageProcessor}
python/sglang/srt/managers/image_processors/qwen_vl.py (new file, mode 100644)

import asyncio
import math
from typing import List, Union

from PIL import Image

from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
    get_global_processor,
)
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration


# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(BaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
        self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
        self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
        self.image_token_id = hf_config.image_token_id
        self.video_token_id = hf_config.video_token_id
        self.NUM_TOKEN_PER_FRAME = 770
        self.IMAGE_FACTOR = 28
        self.MIN_PIXELS = 4 * 28 * 28
        self.MAX_PIXELS = 16384 * 28 * 28
        self.MAX_PIXELS = 16384 * 28 * 28
        self.MAX_RATIO = 200

    @staticmethod
    def _process_images_task(images, input_text, _hf_config):
        if isinstance(images, list) and len(images) == 0:
            images = None
        result = get_global_processor().__call__(
            text=[input_text], images=images, padding=True, return_tensors="pt"
        )

        return {
            "input_ids": result.input_ids,
            "pixel_values": getattr(result, "pixel_values", None),
            "image_grid_thw": getattr(result, "image_grid_thw", None),
            "second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
            "video_grid_thws": getattr(result, "video_grid_thws", None),
        }

    async def _process_images(self, images, input_text) -> dict:
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor,
                Qwen2_5VLImageProcessor._process_images_task,
                images,
                input_text,
                self.hf_config,
            )
        else:
            return self._process_images_task(images, input_text, self.hf_config)

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_ids,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs,
    ):
        if not image_data:
            return None
        if isinstance(image_data, str):
            image_data = [image_data]

        image_token = self.IMAGE_TOKEN
        base_output = self.load_images(
            input_ids,
            image_data,
            image_token,
            max_req_input_len,
        )

        def smart_resize(
            height: int,
            width: int,
            factor: int = self.IMAGE_FACTOR,
            min_pixels: int = self.MIN_PIXELS,
            max_pixels: int = self.MAX_PIXELS,
        ) -> tuple[int, int]:
            """
            Rescales the image so that the following conditions are met:
            1. Both dimensions (height and width) are divisible by 'factor'.
            2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
            3. The aspect ratio of the image is maintained as closely as possible.
            """
            if max(height, width) / min(height, width) > self.MAX_RATIO:
                raise ValueError(
                    f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
                )
            h_bar = max(factor, round_by_factor(height, factor))
            w_bar = max(factor, round_by_factor(width, factor))
            if h_bar * w_bar > max_pixels:
                beta = math.sqrt((height * width) / max_pixels)
                h_bar = floor_by_factor(height / beta, factor)
                w_bar = floor_by_factor(width / beta, factor)
            elif h_bar * w_bar < min_pixels:
                beta = math.sqrt(min_pixels / (height * width))
                h_bar = ceil_by_factor(height * beta, factor)
                w_bar = ceil_by_factor(width * beta, factor)
            return h_bar, w_bar

        def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
            width, height = image.size
            min_pixels = self.MIN_PIXELS
            max_pixels = self.MAX_PIXELS
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=size_factor,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            image = image.resize((resized_width, resized_height))
            return image

        def round_by_factor(number: int, factor: int) -> int:
            """Returns the closest integer to 'number' that is divisible by 'factor'."""
            return round(number / factor) * factor

        def ceil_by_factor(number: int, factor: int) -> int:
            """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
            return math.ceil(number / factor) * factor

        def floor_by_factor(number: int, factor: int) -> int:
            """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
            return math.floor(number / factor) * factor

        images = [resize_image(image) for image in base_output.all_frames]

        ret = await self._process_images(images, base_output.input_text)
        return {
            "input_ids": ret["input_ids"].flatten().tolist(),
            "pixel_values": ret["pixel_values"],
            "image_hashes": base_output.image_hashes,
            "modalities": request_obj.modalities or ["image"],
            "image_grid_thws": ret["image_grid_thw"],
            "video_grid_thws": ret["video_grid_thws"],
            "im_start_id": self.IM_START_TOKEN_ID,
            "im_end_id": self.IM_END_TOKEN_ID,
            "im_token_id": self.image_token_id,
            "video_token_id": self.video_token_id,
            "second_per_grid_ts": ret["second_per_grid_ts"],
        }


ImageProcessorMapping = {
    Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor,
    Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor,
}
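
smart_resize snaps both sides to multiples of IMAGE_FACTOR (28) while keeping the pixel count inside [MIN_PIXELS, MAX_PIXELS]. A worked example with a hypothetical 1000x750 input:

# smart_resize(height=750, width=1000) with factor 28:
#   round_by_factor(750, 28)  -> 27 * 28 = 756
#   round_by_factor(1000, 28) -> 36 * 28 = 1008
#   756 * 1008 = 762,048 pixels, inside [4*28*28 = 3,136, 16384*28*28 = 12,845,056],
#   so the image is resized to 1008 x 756 and the 4:3 aspect ratio is preserved exactly.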
python/sglang/srt/managers/multi_modality_padding.py (new file, mode 100644)

from abc import abstractmethod
from typing import Callable, List, Optional, Tuple

from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.utils import logger


class MultiModalityDataPaddingPattern:
    """
    Data tokens (like image tokens) often need special handling during padding
    to maintain model compatibility. This class provides the interface for
    implementing different padding strategies for data tokens
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        Pad the input ids sequence containing data tokens, and replace them with pad_values
        """
        pass


class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)

    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
    """

    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
        self.data_token_id_pairs = data_token_pairs

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        This function will replace the data-tokens inbetween with pad_values accordingly
        """
        pad_values = image_inputs.pad_values
        data_token_pairs = self.data_token_id_pairs
        image_inputs.image_offsets = []
        if data_token_pairs is None:
            data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
        if data_token_pairs is None:
            logger.warning(
                "No data_token_pairs provided, RadixAttention might be influenced."
            )
            return input_ids
        start_token_ids = [s for s, _e in data_token_pairs]
        end_tokens_ids = [e for _s, e in data_token_pairs]
        # First start token marks new data
        data_start_token = start_token_ids[0]

        padded_ids = []
        last_idx = 0
        data_idx = -1

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]

        if len(start_indices) != len(end_indices):
            return input_ids

        for start_idx, end_idx in zip(start_indices, end_indices):
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] == data_start_token:
                data_idx += 1
                image_inputs.image_offsets += [start_idx]

            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
            padded_ids.extend([pad_value] * num_tokens)

            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids)
        return padded_ids


class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
    which needs first to be expanded to multiple tokens, then replaced with their padding values

    This strategy should be used when a single data token represents content that should
    be expanded to multiple tokens during processing.
    """

    def __init__(
        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
    ) -> None:
        self.num_data_token_calc_func = num_data_token_calc_func

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        This function will follow the procedure of:
            1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
            2. the padded data tokens will be replaced with their pad_values
        """
        image_grid_thws = image_inputs.image_grid_thws
        pad_values = image_inputs.pad_values

        image_indices = [
            idx
            for idx, token in enumerate(input_ids)
            if token == image_inputs.im_token_id
        ]

        image_inputs.image_offsets = []

        input_ids_with_image = []
        for image_cnt, _ in enumerate(image_grid_thws):
            print(f"image_cnt {image_cnt}")
            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
            if image_cnt == 0:
                non_image_tokens = input_ids[: image_indices[image_cnt]]
            else:
                non_image_tokens = input_ids[
                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
                ]
            input_ids_with_image.extend(non_image_tokens)
            image_inputs.image_offsets.append(len(input_ids_with_image))
            pad_ids = pad_values * (
                (num_image_tokens + len(pad_values)) // len(pad_values)
            )
            input_ids_with_image.extend(pad_ids[:num_image_tokens])
        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])

        return input_ids_with_image
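
MultiModalityDataPaddingPatternTokenPairs rewrites everything strictly between each start/end token pair with that image's pad value and records the start offsets. A small illustration with made-up token ids; the _Inputs stand-in below only mimics the ImageInputs fields this method touches:

pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(100, 101)])

class _Inputs:                     # hypothetical stand-in for ImageInputs
    pad_values = [999]             # pad value assigned to the first image
    im_start_id = None
    im_end_id = None

image_inputs = _Inputs()
input_ids = [5, 6, 100, 7, 7, 7, 101, 8]    # 100/101 mark the image bounds
padded = pattern.pad_input_tokens(input_ids, image_inputs)
# padded == [5, 6, 100, 999, 999, 999, 101, 8]
# image_inputs.image_offsets == [2]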
python/sglang/srt/managers/schedule_batch.py

@@ -158,15 +158,19 @@ class ImageInputs:
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None

-    # MiniCPMV related
+    # The id of the single-image placeholder token
+    im_token_id: Optional[torch.Tensor] = None
+
     # All the images in the batch should share the same special image
     # bound token ids.
-    im_start_id: Optional[torch.Tensor] = None
-    im_end_id: Optional[torch.Tensor] = None
-    slice_start_id: Optional[torch.Tensor] = None
-    slice_end_id: Optional[torch.Tensor] = None
+    im_start_id: Optional[int] = None
+    im_end_id: Optional[int] = None
+    slice_start_id: Optional[int] = None
+    slice_end_id: Optional[int] = None
     tgt_sizes: Optional[list] = None
+    # denotes the number of valid image tokens in each image
+    images_emb_mask: Optional[torch.BoolTensor] = None

     @staticmethod
     def from_dict(obj: dict):
         ret = ImageInputs(
@@ -186,11 +190,13 @@ class ImageInputs:
             "aspect_ratio_ids",
             "aspect_ratio_mask",
             "image_grid_thws",
+            "im_token_id",
             "im_start_id",
             "im_end_id",
             "slice_start_id",
             "slice_end_id",
             "tgt_sizes",
+            "images_emb_mask",
         ]
         for arg in optional_args:
             if arg in obj:
python/sglang/srt/model_loader/weight_utils.py

@@ -455,7 +455,7 @@ def pt_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        state = torch.load(bin_file, map_location="cpu")
+        state = torch.load(bin_file, map_location="cpu", weights_only=True)
         yield from state.items()
         del state
         torch.cuda.empty_cache()
python/sglang/srt/models/minicpmv.py
View file @
ff2ce0b8
...
@@ -41,7 +41,6 @@ from torch import nn
...
@@ -41,7 +41,6 @@ from torch import nn
from
torch.nn.init
import
trunc_normal_
from
torch.nn.init
import
trunc_normal_
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
sglang.srt.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
sglang.srt.layers.activation
import
get_act_fn
from
sglang.srt.layers.activation
import
get_act_fn
from
sglang.srt.layers.attention.vision
import
VisionAttention
from
sglang.srt.layers.attention.vision
import
VisionAttention
from
sglang.srt.layers.linear
import
(
from
sglang.srt.layers.linear
import
(
...
@@ -51,6 +50,9 @@ from sglang.srt.layers.linear import (
...
@@ -51,6 +50,9 @@ from sglang.srt.layers.linear import (
)
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.managers.multi_modality_padding
import
(
MultiModalityDataPaddingPatternTokenPairs
,
)
from
sglang.srt.managers.schedule_batch
import
ImageInputs
from
sglang.srt.managers.schedule_batch
import
ImageInputs
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.utils
import
set_default_torch_dtype
from
sglang.srt.model_loader.utils
import
set_default_torch_dtype
...
@@ -186,19 +188,16 @@ class Idefics2EncoderLayer(nn.Module):
...
@@ -186,19 +188,16 @@ class Idefics2EncoderLayer(nn.Module):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
num_heads_per_partition
=
divide
(
self
.
num_heads
,
tp_size
)
self
.
self_attn
=
VisionAttention
(
self
.
self_attn
=
VisionAttention
(
embed_dim
=
config
.
hidden_size
,
embed_dim
=
config
.
hidden_size
,
num_heads
=
num_heads
_per_partition
,
num_heads
=
self
.
num_heads
,
projection_size
=
config
.
intermediate_size
,
projection_size
=
config
.
intermediate_size
,
use_qkv_parallel
=
True
,
use_qkv_parallel
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
dropout
=
config
.
attention_dropout
,
dropout
=
config
.
attention_dropout
,
use_context_forward
=
False
,
use_context_forward
=
False
,
use_full
_precision
_softmax
=
True
,
softmax_in_single
_precision
=
True
,
flatten_batch
=
False
,
flatten_batch
=
False
,
prefix
=
add_prefix
(
"self_attn"
,
prefix
),
prefix
=
add_prefix
(
"self_attn"
,
prefix
),
)
)
...
@@ -708,21 +707,21 @@ class MiniCPMVBaseModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         pad_values: List[int],
-        im_start_id: torch.Tensor,
-        im_end_id: torch.Tensor,
-        slice_start_id: Optional[torch.Tensor] = None,
-        slice_end_id: Optional[torch.Tensor] = None,
+        im_start_id: int,
+        im_end_id: int,
+        slice_start_id: Optional[int] = None,
+        slice_end_id: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Returns a tensor indicating the bounds (start and end token ids) of the images
         """
         # All the images in the batch should share the same special image
         # bound token ids.
-        start_cond = input_ids == im_start_id[0]
-        end_cond = input_ids == im_end_id[0]
+        start_cond = input_ids == im_start_id
+        end_cond = input_ids == im_end_id
         if slice_start_id is not None:
-            start_cond |= input_ids == slice_start_id[0]
-            end_cond |= input_ids == slice_end_id[0]
+            start_cond |= input_ids == slice_start_id
+            end_cond |= input_ids == slice_end_id

         (image_start_tokens,) = torch.where(start_cond)
         image_start_tokens += 1
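Since the special ids now arrive as plain ints rather than one-element tensors, the comparisons above drop their [0] indexing. A self-contained illustration of the same mask-and-torch.where logic, with made-up token ids:

# Illustrative only: how integer start/end ids combine into boolean masks and
# torch.where positions, mirroring the hunk above. Token id values are made up.
import torch

input_ids = torch.tensor([5, 101, 7, 7, 102, 9, 103, 7, 104])
im_start_id, im_end_id = 101, 102          # assumed image start/end ids
slice_start_id, slice_end_id = 103, 104    # assumed slice start/end ids

start_cond = (input_ids == im_start_id) | (input_ids == slice_start_id)
end_cond = (input_ids == im_end_id) | (input_ids == slice_end_id)

(image_start_tokens,) = torch.where(start_cond)
image_start_tokens += 1                    # first token *inside* each region
(image_end_tokens,) = torch.where(end_cond)

print(image_start_tokens.tolist())  # [2, 7]
print(image_end_tokens.tolist())    # [4, 8]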
...
@@ -733,6 +732,8 @@ class MiniCPMVBaseModel(nn.Module):
         if (
             len(image_start_tokens) + 1 == len(image_end_tokens)
             and input_ids[0] in pad_values
+            and len(image_start_tokens) != 0
+            and len(image_end_tokens) != 0
             and image_end_tokens[0] < image_start_tokens[0]
         ):
             image_start_tokens = torch.cat(
...
@@ -897,9 +898,12 @@ class MiniCPMVBaseModel(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-        if forward_batch.image_inputs is not None and forward_batch.image_inputs != [
-            None
-        ]:
+        if (
+            forward_batch.image_inputs is not None
+            and len(forward_batch.image_inputs) > 0
+            and forward_batch.image_inputs[0] is not None
+        ):
+            # TODO: bath
             kwargs.update(
                 {
                     "pixel_values": (
...
@@ -1135,81 +1139,16 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         return self.resampler(vision_embedding, tgt_sizes)

     def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
-        if not isinstance(image_inputs.im_start_id, list) or not isinstance(
-            image_inputs.im_end_id, list
-        ):
-            return input_ids
-
-        new_input_ids = []
-        last_idx = 0
-        image_idx = -1
-        image_inputs.image_offsets = []
-
         # Get all special token IDs
-        im_start_id = (
-            image_inputs.im_start_id[0].item()
-            if isinstance(image_inputs.im_start_id[0], torch.Tensor)
-            else image_inputs.im_start_id[0]
-        )
-        im_end_id = (
-            image_inputs.im_end_id[0].item()
-            if isinstance(image_inputs.im_end_id[0], torch.Tensor)
-            else image_inputs.im_end_id[0]
-        )
-        slice_start_id = (
-            image_inputs.slice_start_id[0].item()
-            if isinstance(image_inputs.slice_start_id[0], torch.Tensor)
-            else image_inputs.slice_start_id[0]
-        )
-        slice_end_id = (
-            image_inputs.slice_end_id[0].item()
-            if isinstance(image_inputs.slice_end_id[0], torch.Tensor)
-            else image_inputs.slice_end_id[0]
-        )
-
-        # Find all start and end positions for both types
-        start_indices = [
-            i
-            for i, x in enumerate(input_ids)
-            if x == im_start_id or x == slice_start_id
-        ]
-        end_indices = [
-            i
-            for i, x in enumerate(input_ids)
-            if x == im_end_id or x == slice_end_id
-        ]
-
-        if len(start_indices) != len(end_indices):
-            return input_ids
-
-        # Process each region (both image and slice)
-        for start_idx, end_idx in zip(start_indices, end_indices):
-            # Add non-image tokens before this region
-            new_input_ids.extend(
-                input_ids[last_idx : start_idx + 1]
-            )  # include start token
-
-            is_image_start = input_ids[start_idx] == im_start_id
-            if is_image_start:
-                image_inputs.image_offsets += [start_idx]
-                image_idx += 1
-
-            num_tokens = end_idx - start_idx - 1  # exclude start and end tokens
-
-            # Generate pad_ids
-            pad_values = [image_inputs.pad_values[image_idx]]
-            pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
-            pad_ids = pad_ids[:num_tokens]
-
-            # Add pad_ids
-            new_input_ids.extend(pad_ids)
-
-            # Update last_idx to after end token
-            last_idx = end_idx
-
-        # Add remaining tokens after last region
-        new_input_ids.extend(input_ids[last_idx:])
-        assert len(input_ids) == len(new_input_ids)
-        return new_input_ids
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id
+        slice_start_id: int = image_inputs.slice_start_id
+        slice_end_id: int = image_inputs.slice_end_id
+
+        media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+        return pattern.pad_input_tokens(input_ids, image_inputs)
 _SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
...
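The hand-rolled padding loop deleted above is replaced by MultiModalityDataPaddingPatternTokenPairs, which the Qwen2-VL models below import as well; the shared implementation lives in the new python/sglang/srt/managers/multi_modality_padding.py. The sketch below only illustrates the token-pair idea (replace whatever sits between each start/end pair with that image's pad value); it is not the committed implementation:

# Hedged sketch of the token-pair padding idea; not the committed implementation.
from typing import List, Tuple


def pad_between_token_pairs(
    input_ids: List[int],
    token_pairs: List[Tuple[int, int]],
    pad_values: List[int],
) -> List[int]:
    start_ids = {start for start, _ in token_pairs}
    end_ids = {end for _, end in token_pairs}

    padded = list(input_ids)
    region_idx = -1
    inside = False
    for i, tok in enumerate(input_ids):
        if tok in start_ids:
            inside = True
            region_idx += 1
        elif tok in end_ids:
            inside = False
        elif inside:
            # Replace every token between a start/end pair with a pad value.
            padded[i] = pad_values[region_idx % len(pad_values)]
    return padded


if __name__ == "__main__":
    # 100/101 = image start/end, 102/103 = slice start/end (made-up ids)
    ids = [1, 100, 7, 7, 101, 2, 102, 7, 103, 3]
    print(pad_between_token_pairs(ids, [(100, 101), (102, 103)], pad_values=[-9]))
    # [1, 100, -9, -9, 101, 2, 102, -9, 103, 3]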
python/sglang/srt/models/mllama.py
View file @
ff2ce0b8
...
@@ -202,7 +202,7 @@ class MllamaVisionEncoderLayer(nn.Module):
             quant_config=None,
             dropout=0.0,
             use_context_forward=False,
-            use_full_precision_softmax=False,
+            softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
...
python/sglang/srt/models/qwen2_5_vl.py
View file @
ff2ce0b8
...
@@ -47,6 +47,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.multi_modality_padding import (
+    MultiModalityDataPaddingPatternTokenPairs,
+)
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
...
@@ -121,12 +124,12 @@ class Qwen2_5_VisionBlock(nn.Module):
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
         if attn_implementation == "sdpa":
             use_context_forward = False
-            use_full_precision_softmax = False
+            softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
-            use_full_precision_softmax = False
+            softmax_in_single_precision = False
             use_context_forward = True
         elif attn_implementation == "eager":
-            use_full_precision_softmax = True
+            softmax_in_single_precision = True
             use_context_forward = False
         self.attn = VisionAttention(
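The branches above only rename the flag: use_full_precision_softmax becomes softmax_in_single_precision, while the use_context_forward choices stay the same. Read as a table (a sketch, not code from this commit):

# Sketch: how each attn_implementation maps onto the two VisionAttention flags
# after the rename. Values mirror the branches in the hunk above.
FLAGS = {
    # attn_implementation: (use_context_forward, softmax_in_single_precision)
    "sdpa": (False, False),
    "flash_attention_2": (True, False),
    "eager": (False, True),
}

use_context_forward, softmax_in_single_precision = FLAGS["sdpa"]
print(use_context_forward, softmax_in_single_precision)  # False False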
...
@@ -135,7 +138,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             projection_size=dim,
             use_qkv_parallel=False,
             use_context_forward=use_context_forward,
-            use_full_precision_softmax=use_full_precision_softmax,
+            softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
...
@@ -149,12 +152,17 @@ class Qwen2_5_VisionBlock(nn.Module):
         )

     def forward(
-        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
         hidden_states = self.norm1(x)
         hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
         attn = self.attn(
-            hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            hidden_states,
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
         )
         attn = rearrange(attn, "b s ... -> s b ...")
         x = x + attn
...
@@ -443,6 +451,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())

         # compute cu_seqlens
         cu_seqlens = torch.repeat_interleave(
...
@@ -457,7 +467,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
                 cu_seqlens_now = cu_seqlens
             else:
                 cu_seqlens_now = cu_window_seqlens
-            x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)
+            x = blk(
+                x, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings
+            )

         # adapter
         x = self.merger(x)
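Together with the previous hunk, the vision transformer now builds the rotary (cos, sin) pair once and hands the same position_embeddings tuple to every block, instead of each block rederiving it from rotary_pos_emb. A minimal sketch of that plumbing, with made-up shapes:

# Sketch: precompute (cos, sin) once from per-position rotary frequencies and
# reuse them for every block. Shapes are illustrative, not the model's real ones.
import torch

seq_len, half_dim = 8, 16
rotary_pos_emb = torch.randn(seq_len, half_dim)            # per-position angles
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)  # duplicate to full head dim
position_embeddings = (emb.cos(), emb.sin())               # computed once

def fake_block(x: torch.Tensor, position_embeddings) -> torch.Tensor:
    cos, sin = position_embeddings  # every block receives the same pair
    return x  # attention omitted; only the plumbing is shown

x = torch.randn(seq_len, 1, 2 * half_dim)
for _ in range(3):  # stand-in for `for blk in self.blocks`
    x = fake_block(x, position_embeddings)
print(position_embeddings[0].shape)  # torch.Size([8, 32])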
...
@@ -522,50 +534,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         return num_image_tokens

     def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
-        new_input_ids = []
-        last_idx = 0
-        image_idx = -1
-        image_inputs.image_offsets = []
-
         # Get all special token IDs
-        im_start_id = image_inputs.im_start_id
-        im_end_id = image_inputs.im_end_id
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id

-        # Find all start and end positions for both types
-        start_indices = [i for i, x in enumerate(input_ids) if x == im_start_id]
-        end_indices = [i for i, x in enumerate(input_ids) if x == im_end_id]
-
-        if len(start_indices) != len(end_indices):
-            return input_ids
-
-        # Process each region (both image and slice)
-        for start_idx, end_idx in zip(start_indices, end_indices):
-            # Add non-image tokens before this region
-            new_input_ids.extend(input_ids[last_idx : start_idx + 1])
-
-            is_image_start = input_ids[start_idx] == im_start_id
-            if is_image_start:
-                image_inputs.image_offsets += [start_idx]
-                image_idx += 1
-
-            num_tokens = end_idx - start_idx - 1  # exclude start and end tokens
-
-            # Generate pad_ids
-            pad_values = [image_inputs.pad_values[image_idx]]
-            pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
-            pad_ids = pad_ids[:num_tokens]
-
-            # Add pad_ids
-            new_input_ids.extend(pad_ids)
-
-            # Update last_idx to after end token
-            last_idx = end_idx
-
-        # Add remaining tokens after last region
-        new_input_ids.extend(input_ids[last_idx:])
-        assert len(input_ids) == len(new_input_ids)
-        return new_input_ids
+        media_token_pairs = [(im_start_id, im_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+        return pattern.pad_input_tokens(input_ids, image_inputs)

     def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
...
@@ -629,7 +605,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
         prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
         for i, image in enumerate(forward_batch.image_inputs):
-            if image is None:
+            if image is None or image.pixel_values is None:
                 continue
             start_idx = extend_start_loc_cpu[i]
             prefix_len = prefix_lens_cpu[i]
...
python/sglang/srt/models/qwen2_vl.py
View file @
ff2ce0b8
...
@@ -42,6 +42,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.multi_modality_padding import (
+    MultiModalityDataPaddingPatternTokenPairs,
+)
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
...
@@ -137,12 +140,12 @@ class Qwen2VisionBlock(nn.Module):
         mlp_hidden_dim = int(dim * mlp_ratio)
         if attn_implementation == "sdpa":
             use_context_forward = False
-            use_full_precision_softmax = False
+            softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
-            use_full_precision_softmax = False
+            softmax_in_single_precision = False
             use_context_forward = True
         elif attn_implementation == "eager":
-            use_full_precision_softmax = True
+            softmax_in_single_precision = True
             use_context_forward = False
         self.attn = VisionAttention(
...
@@ -151,7 +154,7 @@ class Qwen2VisionBlock(nn.Module):
             projection_size=dim,
             use_qkv_parallel=False,
             use_context_forward=use_context_forward,
-            use_full_precision_softmax=use_full_precision_softmax,
+            softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
...
@@ -165,12 +168,17 @@ class Qwen2VisionBlock(nn.Module):
         )

     def forward(
-        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
         hidden_states = self.norm1(x)
         hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
         attn = self.attn(
-            hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            hidden_states,
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
         )
         attn = rearrange(attn, "b s ... -> s b ...")
         x = x + attn
...
@@ -392,7 +400,8 @@ class Qwen2VisionTransformer(nn.Module):
         # compute position embedding
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())

         # compute cu_seqlens
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
...
@@ -402,7 +411,7 @@ class Qwen2VisionTransformer(nn.Module):
         # transformers
         x = x.unsqueeze(1)
         for blk in self.blocks:
-            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)

         # adapter
         x = self.merger(x)
...
@@ -425,40 +434,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )
         return num_image_tokens

-    # Use grid_t * grid_w * grid_h to pad tokens for each image
-    # add replaced padding by unique image hash
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
-        image_grid_thws = image_inputs.image_grid_thws
-        pad_values = image_inputs.pad_values
-
-        image_indices = [
-            idx
-            for idx, token in enumerate(input_ids)
-            if token == self.config.image_token_id
-        ]
-        image_inputs.image_offsets = []
-
-        input_ids_with_image = []
-        for image_cnt, _ in enumerate(image_grid_thws):
-            num_image_tokens = self.calculate_num_image_tokens(
-                image_grid_thws[image_cnt]
-            )
-            if image_cnt == 0:
-                non_image_tokens = input_ids[: image_indices[image_cnt]]
-            else:
-                non_image_tokens = input_ids[
-                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
-                ]
-            input_ids_with_image.extend(non_image_tokens)
-            image_inputs.image_offsets.append(len(input_ids_with_image))
-            pad_ids = pad_values * (
-                (num_image_tokens + len(pad_values)) // len(pad_values)
-            )
-            input_ids_with_image.extend(pad_ids[:num_image_tokens])
-        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
-
-        return input_ids_with_image
-
     def __init__(
         self,
         config: Qwen2VLConfig,
...
@@ -494,6 +469,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

+    # Use grid_t * grid_w * grid_h to pad tokens for each image
+    # add replaced padding by unique image hash
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        # Get all special token IDs
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id
+
+        media_token_pairs = [(im_start_id, im_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+        return pattern.pad_input_tokens(input_ids, image_inputs)
+
     def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
         image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
...
@@ -556,12 +542,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
         prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
         for i, image in enumerate(forward_batch.image_inputs):
-            if image is None:
+            if image is None or image.pixel_values is None:
                 continue
             start_idx = extend_start_loc_cpu[i]
             prefix_len = prefix_lens_cpu[i]
-            pixel_values = image.pixel_values.clone()
+            pixel_values = torch.tensor(image.pixel_values, device="cuda")
             image_grid_thws = torch.tensor(
                 np.array(image.image_grid_thws), device="cuda"
             )
...
@@ -579,15 +565,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                     image_grid_thws[idx]
                 )
-                left_idx = start_idx + (image_offset - prefix_len)
-                right_idx = (
-                    start_idx + (image_offset - prefix_len) + num_image_tokens
-                )
+                left_idx = start_idx + (image_offset - prefix_len + 1)
+                right_idx = left_idx + num_image_tokens
                 inputs_embeds[left_idx:right_idx] = image_embeds[
                     image_embeds_offset : image_embeds_offset + num_image_tokens
                 ]
                 image_embeds_offset += num_image_tokens

-            input_ids = None
             hidden_states = self.model(
                 input_ids=input_ids,
...
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment