change / sglang / Commits / 01090e8a

Commit 01090e8a (unverified), authored Mar 13, 2025 by Mick; committed by GitHub, Mar 12, 2025
Parent: 6f43a9b9

model: Support Janus-pro (#3203)

Showing 13 changed files with 2957 additions and 15 deletions (+2957, −15)
benchmark/mmmu/eval_utils.py                                          +2     -0
docs/references/supported_models.md                                   +1     -0
python/sglang/lang/chat_template.py                                   +29    -0
python/sglang/srt/configs/__init__.py                                 +2     -0
python/sglang/srt/configs/janus_pro.py                                +629   -0
python/sglang/srt/configs/model_config.py                             +19    -12
python/sglang/srt/conversation.py                                     +15    -0
python/sglang/srt/hf_transformers_utils.py                            +15    -1
python/sglang/srt/layers/attention/vision.py                          +1     -1
python/sglang/srt/managers/image_processors/base_image_processor.py   +14    -1
python/sglang/srt/managers/image_processors/janus_pro.py              +79    -0
python/sglang/srt/models/deepseek_janus_pro.py                        +2127  -0
test/srt/test_vision_openai_server.py                                 +24    -0
benchmark/mmmu/eval_utils.py

@@ -26,6 +26,7 @@ class EvalArgs:
     backend: str = "engine"
     seed: int = 42
     split: str = "validation"
     # Default setting to make the benchmark available on A100 for most 7B models
+    image_pixels_limit: int = 4300000
     result_filename: str = ""
     prompt_format_file: str = "prompt_format.yaml"

@@ -38,6 +39,7 @@ class EvalArgs:
         parser.add_argument("--result-filename", type=str, default=EvalArgs.result_filename)
+        parser.add_argument("--image-pixels-limit", type=int, default=EvalArgs.image_pixels_limit)
docs/references/supported_models.md

@@ -31,6 +31,7 @@
 - Phi-3 / Phi-4
 - Phi-3-Small
 - IBM Granite 3
+- Janus-Pro-1B / Janus-Pro-7B

 ## Embedding Models
python/sglang/lang/chat_template.py

@@ -230,6 +230,29 @@ register_chat_template(
     )
 )

+register_chat_template(
+    ChatTemplate(
+        name="janus-pro",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "",
+                "",
+            ),
+            "User": (
+                "<|User|>",
+                "",
+            ),
+            "assistant": (
+                "<|Assistant|>",
+                "<|end▁of▁sentence|>",
+            ),
+        },
+        stop_str=("<|end▁of▁sentence|>",),
+        image_token="<image_placeholder>\n",
+    )
+)

 # The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
 register_chat_template(
     ChatTemplate(

@@ -384,6 +407,12 @@ def match_deepseek(model_path: str):
     return get_chat_template("deepseek-v3")


+@register_chat_template_matching_function
+def match_deepseek_janus_pro(model_path: str):
+    if "janus" in model_path.lower():
+        return get_chat_template("janus-pro")
+
+
 @register_chat_template_matching_function
 def match_dbrx(model_path: str):
     if "dbrx" in model_path.lower() and "instruct" in model_path.lower():
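For reference, the template above yields prompts of the following shape. This is a hand-assembled sketch built only from the registered prefixes and suffixes (the question text is invented, and no sglang rendering API is called):

# Hand-assembled from the "janus-pro" role prefixes/suffixes above (illustrative only).
image_token = "<image_placeholder>\n"
prompt = (
    ""                 # "system" has an empty prefix and suffix
    + "<|User|>"       # "User" prefix; its suffix is ""
    + image_token
    + "What is in this image?"
    + "<|Assistant|>"  # the model generates until "<|end▁of▁sentence|>"
)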
python/sglang/srt/configs/__init__.py

 from sglang.srt.configs.chatglm import ChatGLMConfig
 from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.exaone import ExaoneConfig
+from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.qwen2_5_vl_config import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,

@@ -12,4 +13,5 @@ __all__ = [
     "DbrxConfig",
     "Qwen2_5_VLConfig",
     "Qwen2_5_VLVisionConfig",
+    "MultiModalityConfig",
 ]
python/sglang/srt/configs/janus_pro.py  (new file, mode 100644)

# Adapted from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

import numpy as np
import PIL
import torch
from PIL.Image import Image
from transformers import (
    AutoImageProcessor,
    AutoProcessor,
    BaseImageProcessor,
    BatchFeature,
    LlamaConfig,
    LlamaTokenizerFast,
    PretrainedConfig,
    ProcessorMixin,
)
from transformers.image_utils import to_numpy_array

from sglang.srt.mm_utils import expand2square


class DictToObject(dict):
    def __init__(self, dictionary):
        # Note: the committed code calls `super(self).__init__(dictionary)`,
        # which raises TypeError; the zero-argument form is the correct call.
        super().__init__(dictionary)

        for key, value in dictionary.items():
            if isinstance(value, dict):
                value = DictToObject(value)
            setattr(self, key, value)

class VisionConfig(PretrainedConfig):
    model_type = "vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class GenAlignerConfig(PretrainedConfig):
    model_type = "gen_aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class GenHeadConfig(PretrainedConfig):
    model_type = "gen_head"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class AlignerConfig(PretrainedConfig):
    model_type = "aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class GenVisionConfig(PretrainedConfig):
    model_type = "gen_vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})

@dataclass
class SigLIPVisionCfg:
    width: int = 1152
    layers: Union[Tuple[int, int, int, int], int] = 27
    heads: int = 16
    patch_size: int = 14
    image_size: Union[Tuple[int, int], int] = 336
    global_pool: str = "map"
    mlp_ratio: float = 3.7362
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False


class MultiModalityConfig(PretrainedConfig):
    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig

    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig

    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        vision_config = kwargs.get("vision_config", {})
        self.vision_config = VisionConfig(**vision_config)

        aligner_config = kwargs.get("aligner_config", {})
        self.aligner_config = AlignerConfig(**aligner_config)

        gen_vision_config = kwargs.get("gen_vision_config", {})
        self.gen_vision_config = GenVisionConfig(**gen_vision_config)

        gen_aligner_config = kwargs.get("gen_aligner_config", {})
        self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)

        gen_head_config = kwargs.get("gen_head_config", {})
        self.gen_head_config = GenHeadConfig(**gen_head_config)

        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, LlamaConfig):
            self.language_config = language_config
        else:
            self.language_config = LlamaConfig(**language_config)
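
# NOTE (illustrative comment, not part of the committed file): HF instantiates
# this config from the checkpoint's config.json, wrapping each nested dict in
# the typed sub-config above, e.g.
#     cfg = MultiModalityConfig(language_config={"hidden_size": 4096})
#     isinstance(cfg.language_config, LlamaConfig)  # True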

class VLMImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image) -> np.ndarray:
        """
        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]
        """
        width, height = pil_img.size
        max_size = max(width, height)

        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]

        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
            # print(f"orig size = {pil_img.size}, new size = {size}")
            raise ValueError("Invalid size!")

        def resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        ):
            if isinstance(size, int):
                w, h = pil_img.size
                if (w <= h and w == size) or (h <= w and h == size):
                    return pil_img
                if w < h:
                    ow = size
                    oh = int(size * h / w)
                else:
                    oh = size
                    ow = int(size * w / h)
                size = (ow, oh)
            else:
                size = (size[1], size[0])

            return pil_img.resize(
                size, resample=interpolation, reducing_gap=None if antialias else 3.0
            )

        pil_img = resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        )

        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)

        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))

        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        if not isinstance(images, list):
            images = [images]
        images: List[np.ndarray] = [self.resize(image) for image in images]
        images = [image[:3, ...] for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        return [3, self.image_size, self.image_size]


class DictOutput(object):
    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


@dataclass
class VLChatProcessorOutput(DictOutput):
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    sft_format: List[str]
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    attention_mask: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_emb_mask: torch.BoolTensor

# FIXME: had to place Official Processor here, since the image_processor module
# would not be imported in all threads, hence the AutoProcessor registration
# would not be effective in some cases
class VLChatProcessor(ProcessorMixin):
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        image_start_tag: str = "<begin_of_image>",
        image_end_tag: str = "<end_of_image>",
        pad_tag: str = "<|▁pad▁|>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            # print(f"Add image tag = {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.pad_tag = pad_tag

        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.ignore_id = ignore_id

        super().__init__(
            image_processor,
            tokenizer,
            **kwargs,
        )

    @property
    def image_token(self):
        return self.image_tag

    @property
    def image_id(self) -> int:
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def image_start_id(self):
        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
        return image_start_id

    @property
    def image_end_id(self):
        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
        return image_end_id

    @property
    def image_start_token(self):
        return self.image_start_tag

    @property
    def image_end_token(self):
        return self.image_end_tag

    @property
    def pad_id(self):
        pad_id = self.tokenizer.vocab.get(self.pad_tag)
        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """
        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """
        input_slices = []

        start = 0
        for index in image_indices:
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens
            input_slices.append(input_ids[start:end])

            # add boi, image tokens, eoi and set the mask as False
            input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))

        return input_ids, num_image_tokens
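
    # Illustrative note (not part of the committed file): with
    # num_image_tokens=576, an input_ids of [t0, <image_placeholder>, t2]
    # expands to
    #     [t0, <begin_of_image>, <image_placeholder> x 576, <end_of_image>, t2]
    # and the returned num_image_tokens equals torch.IntTensor([576]).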

    def process_one(
        self,
        prompt: str = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """
        Args:
            prompt (str): the formatted prompt;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """
        sft_format = prompt
        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # add image tokens to the input_ids
        image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool)
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """
        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """
        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """
        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))

        input_token_max_len = max(seq_lens)
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()

        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )

        return batched_prepares


class VLMImageProcessorConfig(PretrainedConfig):
    model_type = "deepseek_vlm"
    image_size: int
    min_size: int
    image_mean: Union[Tuple[float, float, float], List[float]]
    image_std: Union[Tuple[float, float, float], List[float]]
    rescale_factor: float
    do_normalize: bool

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        self.image_size = image_size
        self.min_size = min_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize

        super().__init__(**kwargs)


AutoProcessor.register(MultiModalityConfig, VLChatProcessor, exist_ok=True)
AutoImageProcessor.register(VLMImageProcessorConfig, None, VLMImageProcessor, None)
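The preprocessing path above can be exercised standalone. A minimal sketch (the file name is hypothetical; image_size=384 matches the SigLIP encoder the Janus-Pro checkpoints configure):

# Sketch: run the VLMImageProcessor defined above on one image.
import PIL.Image

from sglang.srt.configs.janus_pro import VLMImageProcessor

processor = VLMImageProcessor(image_size=384)
img = PIL.Image.open("example.png").convert("RGB")  # hypothetical local file
batch = processor.preprocess(img, return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 384, 384])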
python/sglang/srt/configs/model_config.py

@@ -408,7 +408,7 @@ def _get_and_verify_dtype(
 def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     # We have two ways to determine whether a model is a generative model.
-    # 1. Check the model architectue
+    # 1. Check the model architecture
     # 2. check the `is_embedding` server args
     if (

@@ -424,18 +424,25 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
     return not is_embedding


+multimodal_model_archs = [
+    "LlavaLlamaForCausalLM",
+    "LlavaQwenForCausalLM",
+    "LlavaMistralForCausalLM",
+    "LlavaVidForCausalLM",
+    "Grok1VForCausalLM",
+    "Grok1AForCausalLM",
+    "MllamaForConditionalGeneration",
+    "Qwen2VLForConditionalGeneration",
+    "Qwen2_5_VLForConditionalGeneration",
+    "MiniCPMV",
+    "MultiModalityCausalLM",
+]
+
+
 def is_multimodal_model(model_architectures: List[str]):
-    if (
-        "LlavaLlamaForCausalLM" in model_architectures
-        or "LlavaQwenForCausalLM" in model_architectures
-        or "LlavaMistralForCausalLM" in model_architectures
-        or "LlavaVidForCausalLM" in model_architectures
-        or "Grok1VForCausalLM" in model_architectures
-        or "Grok1AForCausalLM" in model_architectures
-        or "MllamaForConditionalGeneration" in model_architectures
-        or "Qwen2VLForConditionalGeneration" in model_architectures
-        or "Qwen2_5_VLForConditionalGeneration" in model_architectures
-        or "MiniCPMV" in model_architectures
-    ):
+    if any(
+        multi_model_arch in model_architectures
+        for multi_model_arch in multimodal_model_archs
+    ):
         return True
     else:
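This refactor replaces the hard-coded `or` chain with a membership test over `multimodal_model_archs`, so supporting Janus-Pro (architecture `MultiModalityCausalLM`) is a one-line list entry. A quick sketch of the resulting behavior, assuming the elided else branch returns False as the truncated hunk suggests:

from sglang.srt.configs.model_config import is_multimodal_model

assert is_multimodal_model(["MultiModalityCausalLM"])  # Janus-Pro
assert not is_multimodal_model(["LlamaForCausalLM"])   # text-only model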
python/sglang/srt/conversation.py

@@ -631,3 +631,18 @@ register_conv_template(
         image_token="(<image>./</image>)",
     )
 )
+
+# Reference: https://github.com/deepseek-ai/Janus?tab=readme-ov-file#janus-pro
+register_conv_template(
+    Conversation(
+        name="janus-pro",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language",
+        system_template="{system_message}.",
+        roles=("User", "Assistant"),
+        sep="\n\n",
+        sep2="<|end▁of▁sentence|>",
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        stop_str=["<|User|>", "<|end▁of▁sentence|>"],
+        image_token="<image_placeholder>",
+    )
+)
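Under SeparatorStyle.ADD_COLON_TWO (the FastChat-style scheme this module follows, which alternates sep and sep2 after each "role: message" turn), a single-turn prompt renders roughly as below. This is a hand-written sketch, not output captured from the code:

# Rough shape of a rendered "janus-pro" conversation (hand-written sketch).
system = (
    "You are a helpful language and vision assistant. You are able to understand "
    "the visual content that the user provides, and assist the user with a "
    "variety of tasks using natural language."  # system_template appends the "."
)
prompt = system + "\n\n" + "User: <image_placeholder>What is in this image?\n\n" + "Assistant:"
# the assistant turn is terminated by sep2 = "<|end▁of▁sentence|>"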
python/sglang/srt/hf_transformers_utils.py

@@ -30,13 +30,20 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

-from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2_5_VLConfig
+from sglang.srt.configs import (
+    ChatGLMConfig,
+    DbrxConfig,
+    ExaoneConfig,
+    MultiModalityConfig,
+    Qwen2_5_VLConfig,
+)

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
     Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
+    MultiModalityConfig.model_type: MultiModalityConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():

@@ -67,6 +74,13 @@ def get_config(
         model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )

+    # FIXME: Pour contents of janus-pro's language_config to first-level
+    if isinstance(model, str) and model.lower().startswith("deepseek-ai/janus-pro"):
+        assert hasattr(config, "language_config")
+        for key, val in config.language_config.__dict__.items():
+            setattr(config, key, val)
+        setattr(config, "architectures", ["MultiModalityCausalLM"])
+
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
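The flattening block exists because sglang's loader reads LLM hyperparameters (hidden_size, num_hidden_layers, and so on) from the top level of the config, while the Janus-Pro checkpoint nests them under language_config. A sketch of the observable effect, assuming get_config accepts the (model, trust_remote_code, revision) parameters shown in the hunk (this fetches the checkpoint config; the printed values are illustrative):

from sglang.srt.hf_transformers_utils import get_config

config = get_config("deepseek-ai/Janus-Pro-7B", trust_remote_code=True)
print(config.architectures)  # ["MultiModalityCausalLM"], forced by the block above
print(config.hidden_size)    # mirrored up from config.language_config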
python/sglang/srt/layers/attention/vision.py

@@ -6,7 +6,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange, repeat
+from einops import rearrange

 from sglang.srt.distributed import parallel_state
 from sglang.srt.distributed import utils as dist_utils
python/sglang/srt/managers/image_processors/base_image_processor.py

@@ -13,6 +13,7 @@ from PIL import Image
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import load_image
+from sglang.utils import logger

 global global_processor

@@ -22,6 +23,13 @@ def get_global_processor():
     return global_processor


+def init_global_processor(sglang_image_processor, server_args: ServerArgs):
+    """Init the global processor for multi-modal models."""
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = sglang_image_processor._build_processor(server_args=server_args)
+
+
 @dataclasses.dataclass
 class BaseImageProcessorOutput:
     image_hashes: list[int]

@@ -119,6 +127,11 @@ class BaseImageProcessor(ABC):
     ) -> BaseImageProcessorOutput:
+        """
+        Each frame of video/image will be replaced by a single image token
+        Args:
+            discard_alpha_channel: if True, discards the alpha channel in the returned images
+        """
         image_hashes, image_sizes = [], []
         all_frames = []

@@ -133,7 +146,7 @@ class BaseImageProcessor(ABC):
         if return_text:
             text_parts = input_text.split(image_token)

-        # roughly calculate the max number of frames under the max_req_input_len limit
+        # TODO(mick): load from server_args, env, or sampling_params
         MAX_NUM_FRAMES = 30
         estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
         total_frame_count = sum(estimated_frames_list)
python/sglang/srt/managers/image_processors/janus_pro.py  (new file, mode 100644)

import asyncio
from typing import List, Union

from sglang.srt.managers.image_processors.base_image_processor import (
    BaseImageProcessor as SGLangBaseImageProcessor,
)
from sglang.srt.managers.image_processors.base_image_processor import (
    get_global_processor,
)
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM


class JanusProProcessor(SGLangBaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_images_task(images, input_text):
        processor = get_global_processor()
        result = processor.__call__(
            prompt=input_text, images=images, return_tensors="pt"
        )
        return {
            "input_ids": result["input_ids"],
            "pixel_values": result["pixel_values"],
            "images_emb_mask": result["images_emb_mask"],
            "im_start_id": processor.image_start_id,
            "im_end_id": processor.image_end_id,
            "im_token_id": processor.image_id,
        }

    async def _process_images(self, images, input_text):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            image_inputs = await loop.run_in_executor(
                self.executor,
                JanusProProcessor._process_images_task,
                images,
                input_text,
            )
        else:
            image_inputs = self._processor(
                images=images, text=input_text, return_tensors="pt"
            )

        return image_inputs

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_ids,
        request_obj,
        max_req_input_len,
        **kwargs,
    ):
        if not image_data:
            return None

        if not isinstance(image_data, list):
            image_data = [image_data]

        base_out = self.load_images(
            input_ids, image_data, "<image_placeholder>", max_req_input_len
        )
        images = base_out.all_frames
        res = await self._process_images(images=images, input_text=base_out.input_text)

        return {
            "input_ids": res["input_ids"].flatten().tolist(),
            "pixel_values": res["pixel_values"],
            "images_emb_mask": res["images_emb_mask"],
            "image_hashes": base_out.image_hashes,
            "im_start_id": res["im_start_id"],
            "im_end_id": res["im_end_id"],
            "im_token_id": res["im_token_id"],
        }


ImageProcessorMapping = {MultiModalityCausalLM: JanusProProcessor}
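The module-level ImageProcessorMapping keys the new processor by model class; sglang's processor registry scans these mappings to pick an image processor for the loaded architecture. A minimal sketch of the lookup (the dispatch machinery itself lives elsewhere in sglang and is assumed here):

from sglang.srt.managers.image_processors.janus_pro import ImageProcessorMapping
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM

processor_cls = ImageProcessorMapping[MultiModalityCausalLM]
print(processor_cls.__name__)  # JanusProProcessor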
python/sglang/srt/models/deepseek_janus_pro.py  (new file, mode 100644)

(Diff collapsed on the commit page: 2,127 added lines, not shown.)
test/srt/test_vision_openai_server.py

@@ -512,5 +512,29 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
         cls.base_url += "/v1"


+class TestJanusProServer(TestOpenAIVisionServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "deepseek-ai/Janus-Pro-7B"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--chat-template",
+                "janus-pro",
+                "--mem-fraction-static",
+                "0.4",
+            ],
+        )
+        cls.base_url += "/v1"
+
+    def test_video_chat_completion(self):
+        pass
+
+
 if __name__ == "__main__":
     unittest.main()
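Since TestJanusProServer inherits the OpenAI-compatible test suite, the same endpoint can be exercised by hand once the server is up. A sketch (the port and image URL are placeholders; sglang's DEFAULT_URL_FOR_TEST resolves to a local address):

# Sketch: query a locally launched Janus-Pro server through the OpenAI client.
import openai

client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")
response = client.chat.completions.create(
    model="deepseek-ai/Janus-Pro-7B",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "What is in this image?"},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)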