sglang · commit 01090e8a (unverified)

model: Support Janus-pro (#3203)

Authored Mar 13, 2025 by Mick; committed by GitHub Mar 12, 2025
Parent: 6f43a9b9

Changes: 13 files changed, 2957 additions and 15 deletions (+2957, -15).
Changed files:

    benchmark/mmmu/eval_utils.py                                          +2    -0
    docs/references/supported_models.md                                   +1    -0
    python/sglang/lang/chat_template.py                                  +29    -0
    python/sglang/srt/configs/__init__.py                                 +2    -0
    python/sglang/srt/configs/janus_pro.py                              +629    -0
    python/sglang/srt/configs/model_config.py                            +19   -12
    python/sglang/srt/conversation.py                                    +15    -0
    python/sglang/srt/hf_transformers_utils.py                           +15    -1
    python/sglang/srt/layers/attention/vision.py                          +1    -1
    python/sglang/srt/managers/image_processors/base_image_processor.py  +14    -1
    python/sglang/srt/managers/image_processors/janus_pro.py             +79    -0
    python/sglang/srt/models/deepseek_janus_pro.py                     +2127    -0
    test/srt/test_vision_openai_server.py                                +24    -0
benchmark/mmmu/eval_utils.py

@@ -26,6 +26,7 @@ class EvalArgs:
    backend: str = "engine"
    seed: int = 42
    split: str = "validation"
    # Default setting to make the benchmark available on A100 for most 7B models
    image_pixels_limit: int = 4300000
    result_filename: str = ""
    prompt_format_file: str = "prompt_format.yaml"

@@ -38,6 +39,7 @@ class EvalArgs:
        parser.add_argument("--result-filename", type=str, default=EvalArgs.result_filename)
        parser.add_argument("--image-pixels-limit", type=int, default=EvalArgs.image_pixels_limit)
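The new flag caps the per-image pixel count for the MMMU benchmark. A quick sanity check of the default value, as a hedged sketch (the helper below is hypothetical, not part of the diff):

    # Hypothetical sanity check of the default image_pixels_limit (4,300,000 pixels):
    # a 2048 x 2048 image (4,194,304 pixels) fits, a 2100 x 2100 image
    # (4,410,000 pixels) would not.
    def fits_pixel_limit(width: int, height: int, limit: int = 4_300_000) -> bool:
        return width * height <= limit

    assert fits_pixel_limit(2048, 2048)      # 4,194,304 <= 4,300,000
    assert not fits_pixel_limit(2100, 2100)  # 4,410,000 >  4,300,000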
docs/references/supported_models.md

@@ -31,6 +31,7 @@
- Phi-3 / Phi-4
- Phi-3-Small
- IBM Granite 3
- Janus-Pro-1B / Janus-Pro-7B

## Embedding Models
python/sglang/lang/chat_template.py

@@ -230,6 +230,29 @@ register_chat_template(
    )
)

register_chat_template(
    ChatTemplate(
        name="janus-pro",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "",
                "",
            ),
            "User": (
                "<|User|>",
                "",
            ),
            "assistant": (
                "<|Assistant|>",
                "<|end▁of▁sentence|>",
            ),
        },
        stop_str=("<|end▁of▁sentence|>",),
        image_token="<image_placeholder>\n",
    )
)

# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
    ChatTemplate(

@@ -384,6 +407,12 @@ def match_deepseek(model_path: str):
    return get_chat_template("deepseek-v3")


@register_chat_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    if "janus" in model_path.lower():
        return get_chat_template("janus-pro")


@register_chat_template_matching_function
def match_dbrx(model_path: str):
    if "dbrx" in model_path.lower() and "instruct" in model_path.lower():
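The matcher is substring-based, so any model path containing "janus" (case-insensitive) resolves to the new template. A minimal standalone sketch of that lookup logic; the return value is simplified to the template name and the model paths are illustrative:

    # Simplified sketch of the matching function registered above.
    def match_deepseek_janus_pro(model_path: str):
        if "janus" in model_path.lower():
            return "janus-pro"

    assert match_deepseek_janus_pro("deepseek-ai/Janus-Pro-7B") == "janus-pro"
    assert match_deepseek_janus_pro("meta-llama/Llama-3-8B") is None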
python/sglang/srt/configs/__init__.py

from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.qwen2_5_vl_config import (
    Qwen2_5_VLConfig,
    Qwen2_5_VLVisionConfig,

@@ -12,4 +13,5 @@ __all__ = [
    "DbrxConfig",
    "Qwen2_5_VLConfig",
    "Qwen2_5_VLVisionConfig",
    "MultiModalityConfig",
]
python/sglang/srt/configs/janus_pro.py  (new file, mode 100644)

# Adapted from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models

from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

import numpy as np
import PIL
import torch
from PIL.Image import Image
from transformers import (
    AutoImageProcessor,
    AutoProcessor,
    BaseImageProcessor,
    BatchFeature,
    LlamaConfig,
    LlamaTokenizerFast,
    PretrainedConfig,
    ProcessorMixin,
)
from transformers.image_utils import to_numpy_array

from sglang.srt.mm_utils import expand2square
class DictToObject(dict):
    def __init__(self, dictionary):
        super().__init__(dictionary)

        for key, value in dictionary.items():
            if isinstance(value, dict):
                value = DictToObject(value)
            setattr(self, key, value)
class VisionConfig(PretrainedConfig):
    model_type = "vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class GenAlignerConfig(PretrainedConfig):
    model_type = "gen_aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class GenHeadConfig(PretrainedConfig):
    model_type = "gen_head"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class AlignerConfig(PretrainedConfig):
    model_type = "aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})


class GenVisionConfig(PretrainedConfig):
    model_type = "gen_vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = kwargs.get("params", {})
@dataclass
class SigLIPVisionCfg:
    width: int = 1152
    layers: Union[Tuple[int, int, int, int], int] = 27
    heads: int = 16
    patch_size: int = 14
    image_size: Union[Tuple[int, int], int] = 336
    global_pool: str = "map"
    mlp_ratio: float = 3.7362
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False
class MultiModalityConfig(PretrainedConfig):
    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig

    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig

    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        vision_config = kwargs.get("vision_config", {})
        self.vision_config = VisionConfig(**vision_config)

        aligner_config = kwargs.get("aligner_config", {})
        self.aligner_config = AlignerConfig(**aligner_config)

        gen_vision_config = kwargs.get("gen_vision_config", {})
        self.gen_vision_config = GenVisionConfig(**gen_vision_config)

        gen_aligner_config = kwargs.get("gen_aligner_config", {})
        self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)

        gen_head_config = kwargs.get("gen_head_config", {})
        self.gen_head_config = GenHeadConfig(**gen_head_config)

        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, LlamaConfig):
            self.language_config = language_config
        else:
            self.language_config = LlamaConfig(**language_config)
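Each sub-config is rebuilt from a plain dict, so the whole config can be instantiated from the nested JSON in a checkpoint's config.json. A hedged sketch of that construction; the field values below are placeholders, not taken from a real checkpoint:

    # Illustrative construction of MultiModalityConfig from nested dicts.
    config = MultiModalityConfig(
        vision_config={"cls": "CLIPVisionTower", "params": {"model_name": "siglip_large_patch16_384"}},
        language_config={"hidden_size": 2048, "num_hidden_layers": 24},
    )
    assert isinstance(config.language_config, LlamaConfig)
    assert config.vision_config.cls == "CLIPVisionTower"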
class VLMImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image) -> np.ndarray:
        """
        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]
        """
        width, height = pil_img.size
        max_size = max(width, height)

        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]

        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
            # print(f"orig size = {pil_img.size}, new size = {size}")
            raise ValueError("Invalid size!")

        def resize(pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True):
            if isinstance(size, int):
                w, h = pil_img.size
                if (w <= h and w == size) or (h <= w and h == size):
                    return pil_img
                if w < h:
                    ow = size
                    oh = int(size * h / w)
                else:
                    oh = size
                    ow = int(size * w / h)
                size = (ow, oh)
            else:
                size = (size[1], size[0])
            return pil_img.resize(
                size, resample=interpolation, reducing_gap=None if antialias else 3.0
            )

        pil_img = resize(pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True)

        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)

        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))

        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        if not isinstance(images, list):
            images = [images]
        images: List[np.ndarray] = [self.resize(image) for image in images]
        images = [image[:3, ...] for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        return [3, self.image_size, self.image_size]
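The preprocessing pipeline is: aspect-preserving resize so the longer side matches image_size, pad to a square with the mean color, transpose to channels-first, rescale to [0, 1], then normalize. A hedged usage sketch, assuming a locally constructed 384-pixel processor; the sample image is synthetic:

    # Hedged usage sketch of VLMImageProcessor; the image here is synthetic.
    from PIL import Image as PILImage

    processor = VLMImageProcessor(image_size=384)
    img = PILImage.new("RGB", (640, 480), color=(0, 128, 255))

    batch = processor.preprocess([img])  # BatchFeature with tensor_type="pt"
    pixel_values = batch["pixel_values"]
    assert list(pixel_values.shape) == [1, 3, 384, 384]
    assert processor.default_shape == [3, 384, 384]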
class DictOutput(object):
    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


@dataclass
class VLChatProcessorOutput(DictOutput):
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    sft_format: List[str]
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    attention_mask: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_emb_mask: torch.BoolTensor
# FIXME: had to place the official processor here, since the image_processor module
# would not be imported in all threads, hence AutoProcessor registration would not
# be effective in some cases
class VLChatProcessor(ProcessorMixin):
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        image_start_tag: str = "<begin_of_image>",
        image_end_tag: str = "<end_of_image>",
        pad_tag: str = "<|▁pad▁|>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            # print(f"Add image tag = {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.pad_tag = pad_tag

        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.ignore_id = ignore_id

        super().__init__(
            image_processor,
            tokenizer,
            **kwargs,
        )

    @property
    def image_token(self):
        return self.image_tag

    @property
    def image_id(self) -> int:
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def image_start_id(self):
        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
        return image_start_id

    @property
    def image_end_id(self):
        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
        return image_end_id

    @property
    def image_start_token(self):
        return self.image_start_tag

    @property
    def image_end_token(self):
        return self.image_end_tag

    @property
    def pad_id(self):
        pad_id = self.tokenizer.vocab.get(self.pad_tag)
        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """
        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """
        input_slices = []

        start = 0
        for index in image_indices:
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens
            input_slices.append(input_ids[start:end])

            # add boi, image tokens, eoi and set the mask as False
            input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))

        return input_ids, num_image_tokens
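With add_special_token left at its False default, each <image_placeholder> in the prompt is replaced by 1 + num_image_tokens + 1 positions: a begin-of-image token, 576 image tokens by default, and an end-of-image token. A quick length check of that arithmetic, as a sketch:

    # Length arithmetic for add_image_token (sketch): a prompt of N tokens with
    # one placeholder becomes N - 1 + 578 tokens, since the placeholder itself
    # is dropped and replaced by <begin_of_image> + 576 image ids + <end_of_image>.
    N, num_image_tokens = 20, 576
    expanded_len = (N - 1) + 1 + num_image_tokens + 1
    assert expanded_len == N + 577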
    def process_one(
        self,
        prompt: str = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """
        Args:
            prompt (str): the formatted prompt;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """
        sft_format = prompt
        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # add image tokens to the input_ids
        image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool)
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """
        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """
        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare
    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """
        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))

        input_token_max_len = max(seq_lens)
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()

        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )

        return batched_prepares
class VLMImageProcessorConfig(PretrainedConfig):
    model_type = "deepseek_vlm"
    image_size: int
    min_size: int
    image_mean: Union[Tuple[float, float, float], List[float]]
    image_std: Union[Tuple[float, float, float], List[float]]
    rescale_factor: float
    do_normalize: bool

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        self.image_size = image_size
        self.min_size = min_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize

        super().__init__(**kwargs)


AutoProcessor.register(MultiModalityConfig, VLChatProcessor, exist_ok=True)
AutoImageProcessor.register(VLMImageProcessorConfig, None, VLMImageProcessor, None)
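These register calls wire MultiModalityConfig to VLChatProcessor and VLMImageProcessorConfig to VLMImageProcessor in the transformers Auto* machinery. A hedged sketch of what that enables, assuming the config type is also registered with AutoConfig (as hf_transformers_utils.py does below) and the checkpoint files are available:

    # Hedged sketch: with the registrations above in effect, AutoProcessor can
    # resolve a Janus-Pro checkpoint to the custom VLChatProcessor.
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("deepseek-ai/Janus-Pro-7B")
    assert type(processor).__name__ == "VLChatProcessor"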
python/sglang/srt/configs/model_config.py

@@ -408,7 +408,7 @@ def _get_and_verify_dtype(
 def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     # We have two ways to determine whether a model is a generative model.
-    # 1. Check the model architectue
+    # 1. Check the model architecture
     # 2. check the `is_embedding` server args
     if (

@@ -424,18 +424,25 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
     return not is_embedding

+multimodal_model_archs = [
+    "LlavaLlamaForCausalLM",
+    "LlavaQwenForCausalLM",
+    "LlavaMistralForCausalLM",
+    "LlavaVidForCausalLM",
+    "Grok1VForCausalLM",
+    "Grok1AForCausalLM",
+    "MllamaForConditionalGeneration",
+    "Qwen2VLForConditionalGeneration",
+    "Qwen2_5_VLForConditionalGeneration",
+    "MiniCPMV",
+    "MultiModalityCausalLM",
+]

 def is_multimodal_model(model_architectures: List[str]):
-    if (
-        "LlavaLlamaForCausalLM" in model_architectures
-        or "LlavaQwenForCausalLM" in model_architectures
-        or "LlavaMistralForCausalLM" in model_architectures
-        or "LlavaVidForCausalLM" in model_architectures
-        or "Grok1VForCausalLM" in model_architectures
-        or "Grok1AForCausalLM" in model_architectures
-        or "MllamaForConditionalGeneration" in model_architectures
-        or "Qwen2VLForConditionalGeneration" in model_architectures
-        or "Qwen2_5_VLForConditionalGeneration" in model_architectures
-        or "MiniCPMV" in model_architectures
+    if any(
+        multi_model_arch in model_architectures
+        for multi_model_arch in multimodal_model_archs
     ):
         return True
     else:
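The rewrite turns registering a new multimodal architecture into a one-line list edit, which is exactly how "MultiModalityCausalLM" (Janus-Pro) is added here. A trivial standalone check of the membership logic, with a shortened list for illustration:

    # The any()-based check, extracted for illustration:
    multimodal_model_archs = ["MiniCPMV", "MultiModalityCausalLM"]

    def is_multimodal_model(model_architectures):
        return any(arch in model_architectures for arch in multimodal_model_archs)

    assert is_multimodal_model(["MultiModalityCausalLM"])  # Janus-Pro
    assert not is_multimodal_model(["LlamaForCausalLM"])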
python/sglang/srt/conversation.py

@@ -631,3 +631,18 @@ register_conv_template(
        image_token="(<image>./</image>)",
    )
)

# Reference: https://github.com/deepseek-ai/Janus?tab=readme-ov-file#janus-pro
register_conv_template(
    Conversation(
        name="janus-pro",
        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language",
        system_template="{system_message}.",
        roles=("User", "Assistant"),
        sep="\n\n",
        sep2="<|end▁of▁sentence|>",
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        stop_str=["<|User|>", "<|end▁of▁sentence|>"],
        image_token="<image_placeholder>",
    )
)
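Assuming SeparatorStyle.ADD_COLON_TWO keeps its fastchat-derived behavior (each turn rendered as "role: message", turns joined alternately by sep and sep2), a one-turn janus-pro prompt would render roughly as sketched below; the user text is illustrative:

    You are a helpful language and vision assistant. [...] using natural language.

    User: <image_placeholder>Describe this picture.

    Assistant: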
python/sglang/srt/hf_transformers_utils.py

@@ -30,13 +30,20 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

-from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2_5_VLConfig
+from sglang.srt.configs import (
+    ChatGLMConfig,
+    DbrxConfig,
+    ExaoneConfig,
+    MultiModalityConfig,
+    Qwen2_5_VLConfig,
+)

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
     Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
+    MultiModalityConfig.model_type: MultiModalityConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():

@@ -67,6 +74,13 @@ def get_config(
         model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )

+    # FIXME: Pour the contents of janus-pro's language_config to the first level
+    if isinstance(model, str) and model.lower().startswith("deepseek-ai/janus-pro"):
+        assert hasattr(config, "language_config")
+        for key, val in config.language_config.__dict__.items():
+            setattr(config, key, val)
+        setattr(config, "architectures", ["MultiModalityCausalLM"])
+
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
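Flattening language_config onto the top-level config lets the rest of the engine read the usual Llama fields (hidden_size, num_hidden_layers, ...) without knowing about the nesting. A hedged, self-contained illustration of the "pour up one level" loop with dummy objects and an illustrative value:

    # Local illustration of the flattening loop above, using stand-in objects.
    class Cfg:
        pass

    config = Cfg()
    config.language_config = Cfg()
    config.language_config.hidden_size = 4096  # illustrative value

    for key, val in config.language_config.__dict__.items():
        setattr(config, key, val)
    setattr(config, "architectures", ["MultiModalityCausalLM"])

    assert config.hidden_size == 4096
    assert config.architectures == ["MultiModalityCausalLM"]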
python/sglang/srt/layers/attention/vision.py

@@ -6,7 +6,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange, repeat
+from einops import rearrange

 from sglang.srt.distributed import parallel_state
 from sglang.srt.distributed import utils as dist_utils
python/sglang/srt/managers/image_processors/base_image_processor.py

@@ -13,6 +13,7 @@ from PIL import Image
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import load_image
+from sglang.utils import logger

 global global_processor

@@ -22,6 +23,13 @@ def get_global_processor():
     return global_processor

+def init_global_processor(sglang_image_processor, server_args: ServerArgs):
+    """Init the global processor for multi-modal models."""
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = sglang_image_processor._build_processor(server_args=server_args)

 @dataclasses.dataclass
 class BaseImageProcessorOutput:
     image_hashes: list[int]

@@ -119,6 +127,11 @@ class BaseImageProcessor(ABC):
     ) -> BaseImageProcessorOutput:
         """
         Each frame of video/image will be replaced by a single image token

         Args:
             discard_alpha_channel: if True, discards the alpha channel in the returned images
         """
         image_hashes, image_sizes = [], []
         all_frames = []

@@ -133,7 +146,7 @@ class BaseImageProcessor(ABC):
         if return_text:
             text_parts = input_text.split(image_token)

         # roughly calculate the max number of frames under the max_req_input_len limit
         # TODO(mick): load from server_args, env, or sampling_params
         MAX_NUM_FRAMES = 30
         estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
         total_frame_count = sum(estimated_frames_list)
python/sglang/srt/managers/image_processors/janus_pro.py  (new file, mode 100644)

import asyncio
from typing import List, Union

from sglang.srt.managers.image_processors.base_image_processor import (
    BaseImageProcessor as SGLangBaseImageProcessor,
)
from sglang.srt.managers.image_processors.base_image_processor import (
    get_global_processor,
)
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM


class JanusProProcessor(SGLangBaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_images_task(images, input_text):
        processor = get_global_processor()
        result = processor.__call__(prompt=input_text, images=images, return_tensors="pt")
        return {
            "input_ids": result["input_ids"],
            "pixel_values": result["pixel_values"],
            "images_emb_mask": result["images_emb_mask"],
            "im_start_id": processor.image_start_id,
            "im_end_id": processor.image_end_id,
            "im_token_id": processor.image_id,
        }

    async def _process_images(self, images, input_text):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            image_inputs = await loop.run_in_executor(
                self.executor,
                JanusProProcessor._process_images_task,
                images,
                input_text,
            )
        else:
            image_inputs = self._processor(images=images, text=input_text, return_tensors="pt")

        return image_inputs

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_ids,
        request_obj,
        max_req_input_len,
        **kwargs,
    ):
        if not image_data:
            return None

        if not isinstance(image_data, list):
            image_data = [image_data]

        base_out = self.load_images(
            input_ids, image_data, "<image_placeholder>", max_req_input_len
        )
        images = base_out.all_frames
        res = await self._process_images(images=images, input_text=base_out.input_text)

        return {
            "input_ids": res["input_ids"].flatten().tolist(),
            "pixel_values": res["pixel_values"],
            "images_emb_mask": res["images_emb_mask"],
            "image_hashes": base_out.image_hashes,
            "im_start_id": res["im_start_id"],
            "im_end_id": res["im_end_id"],
            "im_token_id": res["im_token_id"],
        }


ImageProcessorMapping = {MultiModalityCausalLM: JanusProProcessor}
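ImageProcessorMapping ties the model class to this processor so the scheduler can resolve it by architecture. A hedged sketch of the dict process_images_async returns for a single image; the shapes assume the default 576-token, 384x384 configuration and are illustrative:

    # Hedged sketch of the process_images_async result for one image:
    # {
    #     "input_ids":       [...],                 # flattened prompt ids, incl. 576 image ids
    #     "pixel_values":    tensor[1, 1, 3, 384, 384],
    #     "images_emb_mask": tensor[1, 1, 576],     # True where image embeddings are placed
    #     "image_hashes":    [...],
    #     "im_start_id": ..., "im_end_id": ..., "im_token_id": ...,
    # }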
python/sglang/srt/models/deepseek_janus_pro.py  (new file, mode 100644)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Copied and Adapted from:
# https://github.com/deepseek-ai/Janus

import collections
import math
import os
from dataclasses import field
from enum import Enum
from functools import partial
from itertools import repeat
from typing import (
    Callable,
    Final,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, _assert, nn
from torch.nn.init import trunc_normal_
from transformers import AutoModel, PreTrainedModel

from sglang.srt.configs.janus_pro import *
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
    MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.utils import logger
#################################################################################
#                               VQ Model Configs                                #
#################################################################################
# Copied from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models/vq_model.py


@dataclass
class ModelArgs:
    codebook_size: int = 16384
    codebook_embed_dim: int = 8
    codebook_l2_norm: bool = True
    codebook_show_usage: bool = True
    commit_loss_beta: float = 0.25
    entropy_loss_ratio: float = 0.0

    encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    z_channels: int = 256
    dropout_p: float = 0.0
def named_apply(
    fn: Callable,
    module: nn.Module,
    name="",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=child_name,
            depth_first=depth_first,
            include_root=True,
        )
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


def VQ_16(**kwargs):
    return VQModel(
        ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs)
    )


VQ_models = {"VQ-16": VQ_16}
import collections.abc


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse
def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        logger.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    if tensor.dtype in [torch.float16, torch.bfloat16]:
        # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
        og_dtype = tensor.dtype
        tensor = tensor.to(torch.float32)
        tensor.erfinv_()
        tensor = tensor.to(og_dtype)
    else:
        tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    if tensor.dtype == torch.float16:
        # The `clamp_` op is not (yet?) defined in float16+cpu
        tensor = tensor.to(torch.float32)
        tensor.clamp_(min=a, max=b)
    else:
        tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
    tensor: torch.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
):
    """Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \\leq \\text{mean} \\leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)


to_2tuple = _ntuple(2)
class Format(str, Enum):
    NCHW = "NCHW"
    NHWC = "NHWC"
    NCL = "NCL"
    NLC = "NLC"


def nchw_to(x: torch.Tensor, fmt: Format):
    if fmt == Format.NHWC:
        x = x.permute(0, 2, 3, 1)
    elif fmt == Format.NLC:
        x = x.flatten(2).transpose(1, 2)
    elif fmt == Format.NCL:
        x = x.flatten(2)
    return x
def resample_patch_embed(
    patch_embed,
    new_size: List[int],
    interpolation: str = "bicubic",
    antialias: bool = True,
    verbose: bool = False,
):
    """Resample the weights of the patch embedding kernel to target resolution.
    We resample the patch embedding kernel by approximately inverting the effect
    of patch resizing.

    Code based on:
      https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py

    With this resizing, we can for example load a B/8 filter into a B/16 model
    and, on 2x larger input image, the result will match.

    Args:
        patch_embed: original parameter to be resized.
        new_size (tuple(int, int): target shape (height, width)-only.
        interpolation (str): interpolation for resize
        antialias (bool): use anti-aliasing filter in resize
        verbose (bool): log operation
    Returns:
        Resized patch embedding kernel.
    """
    import numpy as np

    try:
        from torch import vmap
    except ImportError:
        from functorch import vmap

    assert len(patch_embed.shape) == 4, "Four dimensions expected"
    assert len(new_size) == 2, "New shape should only be hw"
    old_size = patch_embed.shape[-2:]
    if tuple(old_size) == tuple(new_size):
        return patch_embed

    if verbose:
        logger.info(
            f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation."
        )

    def resize(x_np, _new_size):
        x_tf = torch.Tensor(x_np)[None, None, ...]
        x_upsampled = F.interpolate(
            x_tf, size=_new_size, mode=interpolation, antialias=antialias
        )[0, 0, ...].numpy()
        return x_upsampled

    def get_resize_mat(_old_size, _new_size):
        mat = []
        for i in range(np.prod(_old_size)):
            basis_vec = np.zeros(_old_size)
            basis_vec[np.unravel_index(i, _old_size)] = 1.0
            mat.append(resize(basis_vec, _new_size).reshape(-1))
        return np.stack(mat).T

    resize_mat = get_resize_mat(old_size, new_size)
    resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device)

    def resample_kernel(kernel):
        resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
        return resampled_kernel.reshape(new_size)

    v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
    orig_dtype = patch_embed.dtype
    patch_embed = patch_embed.float()
    patch_embed = v_resample_kernel(patch_embed)
    patch_embed = patch_embed.to(orig_dtype)
    return patch_embed
# Copied from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models/siglip_vit.py
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    output_fmt: Format
    dynamic_img_pad: torch.jit.Final[bool]

    def __init__(
        self,
        img_size: Optional[int] = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = True,
        output_fmt: Optional[str] = None,
        bias: bool = True,
        strict_img_size: bool = True,
        dynamic_img_pad: bool = False,
    ):
        super().__init__()
        self.patch_size = tuple(to_2tuple(patch_size))
        self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)

        if output_fmt is not None:
            self.flatten = False
            self.output_fmt = Format(output_fmt)
        else:
            # flatten spatial dim and transpose to channels last, kept for bwd compat
            self.flatten = flatten
            self.output_fmt = Format.NCHW
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def _init_img_size(self, img_size: Union[int, Tuple[int, int]]):
        assert self.patch_size
        if img_size is None:
            return None, None, None
        img_size = to_2tuple(img_size)
        grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
        num_patches = grid_size[0] * grid_size[1]
        return img_size, grid_size, num_patches

    def set_input_size(
        self,
        img_size: Optional[Union[int, Tuple[int, int]]] = None,
        patch_size: Optional[Union[int, Tuple[int, int]]] = None,
    ):
        new_patch_size = None
        if patch_size is not None:
            new_patch_size = to_2tuple(patch_size)
        if new_patch_size is not None and new_patch_size != self.patch_size:
            with torch.no_grad():
                new_proj = nn.Conv2d(
                    self.proj.in_channels,
                    self.proj.out_channels,
                    kernel_size=new_patch_size,
                    stride=new_patch_size,
                    bias=self.proj.bias is not None,
                )
                new_proj.weight.copy_(
                    resample_patch_embed(self.proj.weight, new_patch_size, verbose=True)
                )
                if self.proj.bias is not None:
                    new_proj.bias.copy_(self.proj.bias)
                self.proj = new_proj
            self.patch_size = new_patch_size
        img_size = img_size or self.img_size
        if img_size != self.img_size or new_patch_size is not None:
            self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)

    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
        if as_scalar:
            return max(self.patch_size)
        else:
            return self.patch_size

    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
        """Get grid (feature) size for given image size taking account of dynamic padding.
        NOTE: must be torchscript compatible so using fixed tuple indexing
        """
        if self.dynamic_img_pad:
            return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(
                img_size[1] / self.patch_size[1]
            )
        else:
            return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]

    def forward(self, x):
        B, C, H, W = x.shape
        if self.img_size is not None:
            if self.strict_img_size:
                _assert(
                    H == self.img_size[0],
                    f"Input height ({H}) doesn't match model ({self.img_size[0]}).",
                )
                _assert(
                    W == self.img_size[1],
                    f"Input width ({W}) doesn't match model ({self.img_size[1]}).",
                )
            elif not self.dynamic_img_pad:
                _assert(
                    H % self.patch_size[0] == 0,
                    f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).",
                )
                _assert(
                    W % self.patch_size[1] == 0,
                    f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).",
                )
        if self.dynamic_img_pad:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        elif self.output_fmt != Format.NCHW:
            x = nchw_to(x, self.output_fmt)
        x = self.norm(x)
        return x
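The patch grid fixes the image-token budget: the Janus-Pro SigLIP tower ("siglip_so400m_patch14_384" in the config table later in this file, with image_size 336 and patch_size 14) produces a 24 x 24 grid, i.e. exactly the 576 tokens VLChatProcessor inserts per image. A quick check of that arithmetic:

    # Patch-count arithmetic for the Janus-Pro vision tower:
    image_size, patch_size = 336, 14
    grid = image_size // patch_size  # 24
    assert grid * grid == 576        # matches VLChatProcessor.num_image_tokens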
class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks

    NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        )
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
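With scale_by_keep, surviving samples are divided by keep_prob so the expected activation is unchanged: E[x * mask / keep_prob] = x. A quick numerical check of that invariant, as a sketch using the function above:

    # Numerical check that scale_by_keep preserves the expectation (sketch).
    import torch

    x = torch.ones(10_000, 8)
    out = drop_path(x, drop_prob=0.2, training=True, scale_by_keep=True)
    # Each row is either all zeros or all 1/0.8 = 1.25; the overall mean stays near 1.
    assert abs(out.mean().item() - 1.0) < 0.05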
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob, 3):0.3f}"
class VisionTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
        drop_path: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            use_context_forward=False,
            softmax_in_single_precision=False,
            dropout=attn_drop,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
LayerType = Union[str, Callable, Type[torch.nn.Module]]
class PatchDropout(nn.Module):
    """
    https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
    """

    return_indices: torch.jit.Final[bool]

    def __init__(
        self,
        prob: float = 0.5,
        num_prefix_tokens: int = 1,
        ordered: bool = False,
        return_indices: bool = False,
    ):
        super().__init__()
        assert 0 <= prob < 1.0
        self.prob = prob
        self.num_prefix_tokens = (
            num_prefix_tokens  # exclude CLS token (or other prefix tokens)
        )
        self.ordered = ordered
        self.return_indices = return_indices

    def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        if not self.training or self.prob == 0.0:
            if self.return_indices:
                return x, None
            return x

        if self.num_prefix_tokens:
            prefix_tokens, x = (
                x[:, : self.num_prefix_tokens],
                x[:, self.num_prefix_tokens :],
            )
        else:
            prefix_tokens = None

        B = x.shape[0]
        L = x.shape[1]
        num_keep = max(1, int(L * (1.0 - self.prob)))
        keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[
            :, :num_keep
        ]
        if self.ordered:
            # NOTE does not need to maintain patch order in typical transformer use,
            # but possibly useful for debug / visualization
            keep_indices = keep_indices.sort(dim=-1)[0]
        x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))

        if prefix_tokens is not None:
            x = torch.cat((prefix_tokens, x), dim=1)

        if self.return_indices:
            return x, keep_indices
        return x
def resample_abs_pos_embed(
    posemb: torch.Tensor,
    new_size: List[int],
    old_size: Optional[List[int]] = None,
    num_prefix_tokens: int = 1,
    interpolation: str = "bicubic",
    antialias: bool = True,
    verbose: bool = False,
):
    # sort out sizes, assume square if old size not provided
    num_pos_tokens = posemb.shape[1]
    num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
    if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
        return posemb

    if old_size is None:
        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
        old_size = hw, hw

    if num_prefix_tokens:
        posemb_prefix, posemb = (
            posemb[:, :num_prefix_tokens],
            posemb[:, num_prefix_tokens:],
        )
    else:
        posemb_prefix, posemb = None, posemb

    # do the interpolation
    embed_dim = posemb.shape[-1]
    orig_dtype = posemb.dtype
    posemb = posemb.float()  # interpolate needs float32
    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(
        posemb, size=new_size, mode=interpolation, antialias=antialias
    )
    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
    posemb = posemb.to(orig_dtype)

    # add back extra (class, etc) prefix tokens
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)

    if not torch.jit.is_scripting() and verbose:
        logger.info(f"Resized position embedding: {old_size} to {new_size}.")

    return posemb
def init_weights(self):
    if self.pos_embed is not None:
        trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
    trunc_normal_(self.latent, std=self.latent_dim**-0.5)
def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, "init_weights"):
        module.init_weights()
class VisionTransformer(nn.Module):
    """Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    dynamic_img_size: Final[bool]

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        num_classes: int = 1000,
        global_pool: Literal["", "avg", "token", "map"] = "token",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: Optional[float] = None,
        class_token: bool = True,
        no_embed_class: bool = False,
        reg_tokens: int = 0,
        pre_norm: bool = False,
        fc_norm: Optional[bool] = None,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        drop_rate: float = 0.0,
        pos_drop_rate: float = 0.0,
        patch_drop_rate: float = 0.0,
        proj_drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
        embed_layer: Callable = PatchEmbed,
        _norm_layer: Optional[LayerType] = None,
        _act_layer: Optional[LayerType] = None,
        block_fn: Type[nn.Module] = VisionTransformerBlock,
        mlp_layer: Type[nn.Module] = Mlp,
        ignore_head: bool = False,
    ) -> None:
        """
        Args:
            img_size: Input image size.
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification head.
            global_pool: Type of global pooling for final sequence (default: 'token').
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            init_values: Layer-scale init values (layer-scale enabled if not None).
            class_token: Use class token.
            no_embed_class: Don't include position embeddings for class (or reg) tokens.
            reg_tokens: Number of register tokens.
            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
            weight_init: Weight initialization scheme.
            embed_layer: Patch embedding layer.
            _norm_layer: Normalization layer.
            _act_layer: MLP activation layer.
            block_fn: Transformer block layer.
        """
        super().__init__()
        assert global_pool in ("", "avg", "token", "map")
        assert class_token or global_pool != "token"
        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
        # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
        # act_layer = get_act_layer(act_layer) or nn.GELU
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        act_layer = nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        self.num_prefix_tokens = 1 if class_token else 0
        self.num_prefix_tokens += reg_tokens
        self.num_reg_tokens = reg_tokens
        self.has_class_token = class_token
        self.no_embed_class = (
            no_embed_class  # don't embed prefix positions (includes reg)
        )
        self.dynamic_img_size = dynamic_img_size
        self.grad_checkpointing = False
        self.ignore_head = ignore_head

        embed_args = {}
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        )
        self.reg_token = (
            nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
        )
        embed_len = (
            num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        )
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.Sequential(
            *[
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_norm=qk_norm,
                    init_values=init_values,
                    proj_drop=proj_drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    mlp_layer=mlp_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # Classifier Head
        if global_pool == "map":
            AttentionPoolLatent.init_weights = init_weights
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

        if weight_init != "skip":
            self.init_weights(weight_init)
def
init_weights
(
self
,
mode
:
Literal
[
"jax"
,
"jax_nlhb"
,
"moco"
,
""
]
=
""
)
->
None
:
assert
mode
in
(
"jax"
,
"jax_nlhb"
,
"moco"
,
""
)
# head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
trunc_normal_
(
self
.
pos_embed
,
std
=
0.02
)
if
self
.
cls_token
is
not
None
:
nn
.
init
.
normal_
(
self
.
cls_token
,
std
=
1e-6
)
named_apply
(
init_weights_vit_timm
,
self
)
@
torch
.
jit
.
ignore
def
no_weight_decay
(
self
)
->
Set
:
return
{
"pos_embed"
,
"cls_token"
,
"dist_token"
}
@
torch
.
jit
.
ignore
def
group_matcher
(
self
,
coarse
:
bool
=
False
)
->
Dict
:
return
dict
(
stem
=
r
"^cls_token|pos_embed|patch_embed"
,
# stem and embed
blocks
=
[(
r
"^blocks\.(\d+)"
,
None
),
(
r
"^norm"
,
(
99999
,))],
)
@
torch
.
jit
.
ignore
def
get_classifier
(
self
)
->
nn
.
Module
:
return
self
.
head
def
reset_classifier
(
self
,
num_classes
:
int
,
global_pool
=
None
)
->
None
:
self
.
num_classes
=
num_classes
if
global_pool
is
not
None
:
assert
global_pool
in
(
""
,
"avg"
,
"token"
,
"map"
)
if
global_pool
==
"map"
and
self
.
attn_pool
is
None
:
assert
(
False
),
"Cannot currently add attention pooling in reset_classifier()."
elif
global_pool
!=
"map "
and
self
.
attn_pool
is
not
None
:
self
.
attn_pool
=
None
# remove attention pooling
self
.
global_pool
=
global_pool
self
.
head
=
(
nn
.
Linear
(
self
.
embed_dim
,
num_classes
)
if
num_classes
>
0
else
nn
.
Identity
()
)

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        if self.dynamic_img_size:
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                [H, W],
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
        else:
            pos_embed = self.pos_embed

        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed

        return self.pos_drop(x)

    def _intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,
    ) -> List[torch.Tensor]:
        outputs, num_blocks = [], len(self.blocks)
        take_indices = set(
            range(num_blocks - n, num_blocks) if isinstance(n, int) else n
        )

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in take_indices:
                outputs.append(x)

        return outputs

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        if self.attn_pool is not None:
            x = self.attn_pool(x)
        elif self.global_pool == "avg":
            x = x[:, self.num_prefix_tokens:].mean(dim=1)
        elif self.global_pool:
            x = x[:, 0]  # class token
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        if not self.ignore_head:
            x = self.forward_head(x)
        return x

def model_name_to_cls(cls_name):
    if "MlpProjector" in cls_name:
        cls = MlpProjector
    elif "CLIPVisionTower" in cls_name:
        cls = CLIPVisionTower
    elif "VQ" in cls_name:
        cls = VQ_models[cls_name]
    elif "vision_head" in cls_name:
        cls = vision_head
    else:
        raise ValueError(f"class_name {cls_name} is invalid.")
    return cls

class vision_head(torch.nn.Module):
    def __init__(self, params):
        super().__init__()
        self.output_mlp_projector = torch.nn.Linear(
            params["n_embed"], params["image_token_embed"]
        )
        self.vision_activation = torch.nn.GELU()
        self.vision_head = torch.nn.Linear(
            params["image_token_embed"], params["image_token_size"]
        )

    def forward(self, x):
        x = self.output_mlp_projector(x)
        x = self.vision_activation(x)
        x = self.vision_head(x)
        return x
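
# Illustrative sketch (not part of the upstream diff): the generation head maps
# per-token hidden states to logits over the VQ codebook. The parameter values
# below are assumptions chosen for demonstration, not Janus-Pro's real config.
def _demo_vision_head():
    head = vision_head(
        {"n_embed": 2048, "image_token_embed": 2048, "image_token_size": 16384}
    )
    logits = head(torch.randn(1, 576, 2048))  # [batch, image tokens, n_embed]
    assert logits.shape == (1, 576, 16384)    # one logit per codebook entry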

SigLIP_MODEL_CONFIG = {
    "siglip_so400m_patch14_384": {
        "image_size": 336,
        "patch_size": 14,
        "width": 1152,
        "layers": 27,
        "heads": 16,
        "mlp_ratio": 3.7362,
        "global_pool": "map",
        "use_checkpoint": False,
    },
    "siglip_so400m_patch14_224": {
        "image_size": 224,
        "patch_size": 14,
        "width": 1152,
        "layers": 27,
        "heads": 16,
        "mlp_ratio": 3.7362,
        "global_pool": "map",
        "use_checkpoint": False,
    },
    "siglip_large_patch16_384": {
        "image_size": 384,
        "patch_size": 16,
        "width": 1024,
        "layers": 24,
        "heads": 16,
        "mlp_ratio": 4,
        "global_pool": "map",
        "use_checkpoint": False,
    },
}

def create_siglip_vit(
    model_name: str = "siglip_so400m_patch14_384",
    image_size: int = 384,
    select_layer: int = -1,
    ckpt_path: str = "",
    **kwargs,
):
    assert (
        model_name in SigLIP_MODEL_CONFIG.keys()
    ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"

    vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])

    if select_layer <= 0:
        layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
    else:
        layers = min(vision_cfg.layers, select_layer)

    model = VisionTransformer(
        img_size=image_size,
        patch_size=vision_cfg.patch_size,
        embed_dim=vision_cfg.width,
        depth=layers,
        num_heads=vision_cfg.heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        class_token=vision_cfg.class_token,
        global_pool=vision_cfg.global_pool,
        ignore_head=kwargs.get("ignore_head", True),
        weight_init=kwargs.get("weight_init", "skip"),
        num_classes=0,
    )

    if ckpt_path:
        state_dict = torch.load(ckpt_path, map_location="cpu")
        incompatible_keys = model.load_state_dict(state_dict, strict=False)
        print(
            f"SigLIP-ViT restores from {ckpt_path},\n"
            f"\tincompatible_keys: {incompatible_keys}."
        )

    return model
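
# Minimal usage sketch (illustrative, not part of the diff): build the tower as
# Janus-Pro's understanding encoder does. With the default ignore_head=True the
# classifier/pooling head is skipped and per-patch features are returned; the
# token count comes from the patch embedding (27x27 = 729 for patch14 at 384px).
def _demo_create_siglip_vit():
    tower = create_siglip_vit("siglip_so400m_patch14_384", image_size=384)
    feats = tower(torch.randn(1, 3, 384, 384))
    print(feats.shape)  # expected [1, 729, 1152]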

class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.
    This transform does not support PIL Image.
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation in-place.
    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        # _log_api_usage_once(self)
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be normalized.
        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

class CLIPVisionTower(nn.Module):
    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: Union[Tuple[int, int], int] = 336,
        select_feature: str = "patch",
        select_layer: int = -2,
        select_layers: list = None,
        ckpt_path: str = "",
        pixel_mean: Optional[List[float]] = None,
        pixel_std: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__()

        self.model_name = model_name
        self.select_feature = select_feature
        self.select_layer = select_layer
        self.select_layers = select_layers

        vision_tower_params = {
            "model_name": model_name,
            "image_size": image_size,
            "ckpt_path": ckpt_path,
            "select_layer": select_layer,
        }
        vision_tower_params.update(kwargs)
        self.vision_tower, self.forward_kwargs = self.build_vision_tower(
            vision_tower_params
        )

        if pixel_mean is not None and pixel_std is not None:
            image_norm = Normalize(mean=pixel_mean, std=pixel_std)
        else:
            image_norm = None

        self.image_norm = image_norm

    @property
    def device(self) -> torch.device:
        return next(self.vision_tower.parameters()).device

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    def build_vision_tower(self, vision_tower_params):
        if self.model_name.startswith("siglip"):
            self.select_feature = "same"
            vision_tower = create_siglip_vit(**vision_tower_params)
            forward_kwargs = dict()
        elif self.model_name.startswith("sam"):
            # vision_tower = create_sam_vit(**vision_tower_params)
            forward_kwargs = dict()
        else:  # huggingface
            from transformers import CLIPVisionModel

            vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
            forward_kwargs = dict(output_hidden_states=True)

        return vision_tower, forward_kwargs

    def feature_select(self, image_forward_outs):
        if isinstance(image_forward_outs, torch.Tensor):
            # the output is already the features of self.select_layer
            image_features = image_forward_outs
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            # if the output has cls_token
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        elif self.select_feature == "same":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def forward(self, images):
        """
        Args:
            images (torch.Tensor): [b, 3, H, W]
        Returns:
            image_features (torch.Tensor): [b, n_patch, d]
        """
        if self.image_norm is not None:
            images = self.image_norm(images)

        image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
        image_features = self.feature_select(image_forward_outs)
        return image_features
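
# Usage sketch (illustrative): for "siglip*" model names the tower is built via
# create_siglip_vit above and select_feature is forced to "same", so features
# pass through feature_select unchanged as [b, n_patch, d].
def _demo_clip_vision_tower():
    tower = CLIPVisionTower(model_name="siglip_large_patch16_384", image_size=384)
    feats = tower(torch.randn(1, 3, 384, 384))
    print(feats.shape)  # expected [1, 576, 1024] for 16px patches at 384px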

class MlpProjector(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        if cfg["projector_type"] == "identity":
            modules = nn.Identity()
        elif cfg["projector_type"] == "linear":
            modules = nn.Linear(cfg["input_dim"], cfg["n_embed"])
        elif cfg["projector_type"] == "mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = [nn.Linear(cfg["input_dim"], cfg["n_embed"])]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"]))
            modules = nn.Sequential(*modules)
        elif cfg["projector_type"] == "low_high_hybrid_split_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            self.high_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2)
            self.low_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2)

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"]))
            modules = nn.Sequential(*modules)
        else:
            raise ValueError(f"Unknown projector type: {cfg['projector_type']}")

        self.layers = modules

    def forward(
        self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
    ):
        """
        Args:
            x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): if it
                is a tuple of torch.Tensor, then it comes from the hybrid vision
                encoder, and x = (high_res_x, low_res_x); otherwise it is the feature
                from the single vision encoder.
        Returns:
            x (torch.Tensor): [b, s, c]
        """
        if isinstance(x_or_tuple, tuple):
            # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu"
            high_x, low_x = x_or_tuple
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)
        else:
            x = x_or_tuple
        return self.layers(x)
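
# Config sketch (values are illustrative assumptions, not read from the diff):
# an "mlp_gelu" aligner with depth=2 expands to Linear -> GELU -> Linear and
# projects vision features into the LLM embedding space.
def _demo_mlp_projector():
    proj = MlpProjector(
        {"projector_type": "mlp_gelu", "input_dim": 1152, "n_embed": 4096, "depth": 2}
    )
    out = proj(torch.randn(1, 729, 1152))
    assert out.shape == (1, 729, 4096)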

class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

# use torch.scaled_dot_product_attention where possible
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention")
if "TIMM_FUSED_ATTN" in os.environ:
    _USE_FUSED_ATTN = int(os.environ["TIMM_FUSED_ATTN"])
else:
    _USE_FUSED_ATTN = (
        1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
    )

# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False


def use_fused_attn(experimental: bool = False) -> bool:
    # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
    if not _HAS_FUSED_ATTN or _EXPORTABLE:
        return False
    if experimental:
        return _USE_FUSED_ATTN > 1
    return _USE_FUSED_ATTN > 0
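
# Behavior sketch (illustrative, mirrors timm's TIMM_FUSED_ATTN convention as
# reproduced above): on PyTorch >= 2.0 the fused path is on by default; setting
# TIMM_FUSED_ATTN=0 before import falls back to the manual softmax(QK^T)V path,
# and TIMM_FUSED_ATTN=2 additionally enables experimental call sites.
def _demo_use_fused_attn():
    print(use_fused_attn())                   # True on PyTorch >= 2.0 defaults
    print(use_fused_attn(experimental=True))  # True only when TIMM_FUSED_ATTN=2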

class AttentionPoolLatent(nn.Module):
    """Attention pooling w/ latent query"""

    fused_attn: torch.jit.Final[bool]

    def __init__(
        self,
        in_features: int,
        out_features: int = None,
        embed_dim: int = None,
        num_heads: int = 8,
        feat_size: Optional[int] = None,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        latent_len: int = 1,
        latent_dim: int = None,
        pos_embed: str = "",
        pool_type: str = "token",
        norm_layer: Optional[nn.Module] = None,
        drop: float = 0.0,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.feat_size = feat_size
        self.scale = self.head_dim**-0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == "abs":
            assert feat_size is not None
            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        self.norm = (
            norm_layer(out_features) if norm_layer is not None else nn.Identity()
        )
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

        self.init_weights()

    def init_weights(self):
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim**-0.5)

    def forward(self, x):
        B, N, C = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        q_latent = self.latent.expand(B, -1, -1)
        q = (
            self.q(q_latent)
            .reshape(B, self.latent_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        kv = (
            self.kv(x)
            .reshape(B, N, 2, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))

        # optional pool if latent seq_len > 1 and pooled output is desired
        if self.pool == "token":
            x = x[:, 0]
        elif self.pool == "avg":
            x = x.mean(1)
        return x
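
# Pooling sketch (illustrative): a learned latent query cross-attends over all
# patch tokens, collapsing [b, n, d] to [b, d]. This is the "map" global_pool
# path selected by the SigLIP configs above; nn.LayerNorm here is an assumption.
def _demo_attention_pool_latent():
    pool = AttentionPoolLatent(in_features=1152, num_heads=16, norm_layer=nn.LayerNorm)
    pooled = pool(torch.randn(2, 729, 1152))
    assert pooled.shape == (2, 1152)  # pool_type="token" keeps the first latent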

class Encoder(nn.Module):
    def __init__(
        self,
        in_channels=3,
        ch=128,
        ch_mult=(1, 1, 2, 2, 4),
        num_res_blocks=2,
        norm_type="group",
        dropout=0.0,
        resamp_with_conv=True,
        z_channels=256,
    ):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)

        # downsampling
        in_ch_mult = (1,) + tuple(ch_mult)
        self.conv_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                res_block.append(
                    ResnetBlock(
                        block_in, block_out, dropout=dropout, norm_type=norm_type
                    )
                )
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # downsample
            if i_level != self.num_resolutions - 1:
                conv_block.downsample = Downsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(
            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
        )
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(
            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
        )

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(
            block_in, z_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        h = self.conv_in(x)
        # downsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.downsample(h)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
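
# Shape sketch (illustrative): the default ch_mult (1, 1, 2, 2, 4) applies four
# 2x downsamples, so a 384px image becomes a 24x24 grid of z_channels-dim
# latents for the quantizer.
def _demo_encoder():
    z = Encoder()(torch.randn(1, 3, 384, 384))
    print(z.shape)  # expected [1, 256, 24, 24]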

class Decoder(nn.Module):
    def __init__(
        self,
        z_channels=256,
        ch=128,
        ch_mult=(1, 1, 2, 2, 4),
        num_res_blocks=2,
        norm_type="group",
        dropout=0.0,
        resamp_with_conv=True,
        out_channels=3,
    ):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks

        block_in = ch * ch_mult[self.num_resolutions - 1]
        # z to block_in
        self.conv_in = nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(
            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
        )
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(
            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
        )

        # upsampling
        self.conv_blocks = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_block.append(
                    ResnetBlock(
                        block_in, block_out, dropout=dropout, norm_type=norm_type
                    )
                )
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # upsample
            if i_level != 0:
                conv_block.upsample = Upsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(
            block_in, out_channels, kernel_size=3, stride=1, padding=1
        )

    @property
    def last_layer(self):
        return self.conv_out.weight

    def forward(self, z):
        # z to block_in
        h = self.conv_in(z)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # upsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks + 1):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

class VectorQuantizer(nn.Module):
    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.entropy_loss_ratio = entropy_loss_ratio
        self.l2_norm = l2_norm
        self.show_usage = show_usage

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        if self.l2_norm:
            self.embedding.weight.data = F.normalize(
                self.embedding.weight.data, p=2, dim=-1
            )
        if self.show_usage:
            # self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
            self.codebook_used = nn.Parameter(torch.zeros(65536))

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = torch.einsum("b c h w -> b h w c", z).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z

        if self.l2_norm:
            z = F.normalize(z, p=2, dim=-1)
            z_flattened = F.normalize(z_flattened, p=2, dim=-1)
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight

        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(embedding**2, dim=1)
            - 2
            * torch.einsum(
                "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding)
            )
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = embedding[min_encoding_indices].view(z.shape)
        perplexity = None
        min_encodings = None
        vq_loss = None
        commit_loss = None
        entropy_loss = None

        # compute loss for embedding
        if self.training:
            vq_loss = torch.mean((z_q - z.detach()) ** 2)
            commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = torch.einsum("b h w c -> b c h w", z_q)

        return (
            z_q,
            (vq_loss, commit_loss, entropy_loss),
            (perplexity, min_encodings, min_encoding_indices),
        )

    def get_codebook_entry(self, indices, shape=None, channel_first=True):
        # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
        if self.l2_norm:
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight
        z_q = embedding[indices]  # (b*h*w, c)

        if shape is not None:
            if channel_first:
                z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
                # reshape back to match original input shape
                z_q = z_q.permute(0, 3, 1, 2).contiguous()
            else:
                z_q = z_q.view(shape)
        return z_q
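
# Straight-through sketch (illustrative): z_q = z + (z_q - z).detach() returns
# the quantized codes in the forward pass while letting gradients reach the
# encoder as if quantization were the identity.
def _demo_vector_quantizer():
    vq = VectorQuantizer(
        n_e=16, e_dim=8, beta=0.25, entropy_loss_ratio=0.0, l2_norm=True, show_usage=False
    )
    z = torch.randn(1, 8, 4, 4, requires_grad=True)
    z_q, _losses, (_, _, indices) = vq(z)
    assert z_q.shape == z.shape and indices.shape == (16,)  # one code per position
    z_q.sum().backward()
    assert z.grad is not None  # gradient flows despite the argmin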

class ResnetBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        norm_type="group",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, norm_type)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = Normalize(out_channels, norm_type)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h

class AttnBlock(nn.Module):
    def __init__(self, in_channels, norm_type="group"):
        super().__init__()
        self.norm = Normalize(in_channels, norm_type)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_

def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, norm_type="group"):
    assert norm_type in ["group", "batch"]
    if norm_type == "group":
        return nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
    elif norm_type == "batch":
        return nn.SyncBatchNorm(in_channels)

class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        if x.dtype != torch.float32:
            x = F.interpolate(
                x.to(torch.float), scale_factor=2.0, mode="nearest"
            ).to(torch.bfloat16)
        else:
            x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.with_conv:
            x = self.conv(x)
        return x

class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
        return x

def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    flat_affinity /= temperature
    probs = F.softmax(flat_affinity, dim=-1)
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == "softmax":
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
    sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss
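
# Intuition sketch: loss = E[H(p_i)] - H(E[p_i]); minimizing it pushes each
# sample toward a confident (low-entropy) code assignment while keeping the
# batch-averaged code usage spread out, which discourages codebook collapse.
def _demo_compute_entropy_loss():
    affinity = torch.randn(32, 16)         # 32 vectors scored against 16 codes
    print(compute_entropy_loss(affinity))  # scalar; negative when both terms help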

class VQModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.encoder = Encoder(
            ch_mult=config.encoder_ch_mult,
            z_channels=config.z_channels,
            dropout=config.dropout_p,
        )
        self.decoder = Decoder(
            ch_mult=config.decoder_ch_mult,
            z_channels=config.z_channels,
            dropout=config.dropout_p,
        )

        self.quantize = VectorQuantizer(
            config.codebook_size,
            config.codebook_embed_dim,
            config.commit_loss_beta,
            config.entropy_loss_ratio,
            config.codebook_l2_norm,
            config.codebook_show_usage,
        )
        self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(
            config.codebook_embed_dim, config.z_channels, 1
        )

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b, shape=None, channel_first=True):
        quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff
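
# Roundtrip sketch (the config fields follow the usage above; the dataclass is
# a hypothetical stand-in for ModelArgs): decode_code is the piece Janus-Pro's
# image-generation path needs at inference, mapping sampled codebook indices
# straight to pixels.
def _demo_vqmodel_decode_code():
    from dataclasses import dataclass

    @dataclass
    class _Args:
        codebook_size: int = 16384
        codebook_embed_dim: int = 8
        commit_loss_beta: float = 0.25
        entropy_loss_ratio: float = 0.0
        codebook_l2_norm: bool = True
        codebook_show_usage: bool = False
        encoder_ch_mult: tuple = (1, 1, 2, 2, 4)
        decoder_ch_mult: tuple = (1, 1, 2, 2, 4)
        z_channels: int = 256
        dropout_p: float = 0.0

    model = VQModel(_Args()).eval()
    codes = torch.randint(0, 16384, (1 * 24 * 24,))
    img = model.decode_code(codes, shape=(1, 8, 24, 24))  # (b, c, h, w)
    print(img.shape)  # expected [1, 3, 384, 384]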

class MultiModalityPreTrainedModel(PreTrainedModel):
    config_class = MultiModalityConfig
    base_model_prefix = "multi_modality"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"

# Copied and adapted from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models/modeling_vlm.py
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
    def __init__(
        self,
        config: MultiModalityConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(config)

        vision_config = config.vision_config
        vision_cls = model_name_to_cls(vision_config.cls)
        self.vision_model = vision_cls(**vision_config.params)

        aligner_config = config.aligner_config
        aligner_cls = model_name_to_cls(aligner_config.cls)
        self.aligner = aligner_cls(aligner_config.params)

        gen_vision_config = config.gen_vision_config
        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
        self.gen_vision_model = gen_vision_cls()

        gen_aligner_config = config.gen_aligner_config
        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)

        gen_head_config = config.gen_head_config
        gen_head_cls = model_name_to_cls(gen_head_config.cls)
        self.gen_head = gen_head_cls(gen_head_config.params)

        self.gen_embed = torch.nn.Embedding(
            gen_vision_config.params["image_token_size"],
            gen_vision_config.params["n_embed"],
        )

        language_config = config.language_config
        self.language_model = LlamaForCausalLM(
            language_config, quant_config=quant_config
        )
        self.logits_processor = LogitsProcessor(config)

    def prepare_images_seq_mask(
        self, input_ids: torch.Tensor, image_inputs: ImageInputs
    ) -> Optional[torch.LongTensor]:
        images_seq_mask = torch.isin(
            input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device)
        )
        if images_seq_mask.sum() == 0:
            # sometimes image_inputs is not empty, but input_ids contain no image token because of prefix-cache
            return None
        else:
            return images_seq_mask

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        inputs_embeds = None
        if (
            forward_batch.image_inputs is not None
            and len(forward_batch.image_inputs) != 0
            and forward_batch.image_inputs[0] is not None
        ):
            image_inputs = forward_batch.image_inputs[0]
            images_seq_mask = self.prepare_images_seq_mask(
                input_ids=input_ids, image_inputs=image_inputs
            )

            if images_seq_mask is not None:
                input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
                inputs_embeds = self.prepare_inputs_embeds(
                    input_ids=input_ids,
                    pixel_values=image_inputs.pixel_values,
                    images_seq_mask=images_seq_mask,
                    images_emb_mask=image_inputs.images_emb_mask,
                )
                input_ids = None

        if input_ids is not None:
            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
            get_embedding=False,
        )

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        images_seq_mask: torch.LongTensor,
        images_emb_mask: torch.BoolTensor,
        **_kwargs,
    ):
        """
        Args:
            input_ids (torch.LongTensor): [b, T]
            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
            images_seq_mask (torch.BoolTensor): [b, T]
            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]

            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)

        Returns:
            input_embeds (torch.Tensor): [b, T, D]
        """

        bs, n = pixel_values.shape[0:2]
        pixel_values = pixel_values.to(
            device=self.vision_model.device, dtype=self.vision_model.dtype
        )
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")

        # [b x n, T2, D]
        images_embeds = self.aligner(self.vision_model(images))

        # [b x n, T2, D] -> [b, n x T2, D]
        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
        # [b, n, T2] -> [b, n x T2]
        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

        # [b, T, D]
        # ignore the image embeddings
        input_ids[input_ids < 0] = 0
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

        # replace with the image embeddings
        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]

        return inputs_embeds

    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
        return self.gen_aligner(self.gen_embed(image_ids))

    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
        im_start_id = image_inputs.im_start_id
        im_end_id = image_inputs.im_end_id
        media_token_pairs = [(im_start_id, im_end_id)]
        helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

        return helper.pad_input_tokens(input_ids, image_inputs)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue
            # skip generation sub model
            if "gen" in name:
                continue

            # adapt to VisionAttention
            name = name.replace(r"self_attn.out_proj", r"self_attn.proj")
            if "vision_model.vision_tower" in name:
                name = name.replace("attn.qkv", "attn.qkv_proj")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # replace the name and load with customized loader
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", None)
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
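
# Masked-scatter sketch (illustrative): the core move in prepare_inputs_embeds
# above. Token positions flagged by images_seq_mask are overwritten, in order,
# with the vision embeddings selected by images_emb_mask.
def _demo_embedding_merge():
    inputs_embeds = torch.zeros(1, 6, 4)   # [b, T, D] text embeddings
    images_embeds = torch.ones(1, 3, 4)    # [b, n x T2, D] vision embeddings
    images_seq_mask = torch.tensor([[False, True, True, True, False, False]])
    images_emb_mask = torch.tensor([[True, True, True]])
    inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
    assert inputs_embeds[0, 1:4].eq(1).all()  # image slots filled in place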

AutoModel.register(config_class=MultiModalityConfig, model_class=MultiModalityCausalLM)

EntryClass = [MultiModalityCausalLM]
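# NOTE: SGLang's model loader discovers architectures through this module-level
# EntryClass list, so exporting MultiModalityCausalLM here is what makes
# Janus-Pro resolvable by model name at server launch.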
test/srt/test_vision_openai_server.py
View file @
01090e8a
...
...
@@ -512,5 +512,29 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
        cls.base_url += "/v1"


class TestJanusProServer(TestOpenAIVisionServer):
    @classmethod
    def setUpClass(cls):
        cls.model = "deepseek-ai/Janus-Pro-7B"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--chat-template",
                "janus-pro",
                "--mem-fraction-static",
                "0.4",
            ],
        )
        cls.base_url += "/v1"

    def test_video_chat_completion(self):
        pass


if __name__ == "__main__":
    unittest.main()