change / sglang / Commits / 2b605ab1

Commit 2b605ab1 (unverified), authored May 27, 2024 by Li Bo, committed by GitHub May 26, 2024
Parent: 947bda73

[Feat/Fix] Refactoring Llava models into single file (#475)

Showing 6 changed files with 118 additions and 688 deletions (+118, -688)
examples/usage/llava/http_llama3_llava_test.py         +4    -6
examples/usage/llava/http_qwen_llava_test.py           +6    -8
python/sglang/srt/managers/router/model_runner.py      +6    -1
python/sglang/srt/models/llava.py                      +102  -3
python/sglang/srt/models/llava_mistral.py (deleted)    +0    -335
python/sglang/srt/models/llava_qwen.py (deleted)       +0    -335
examples/usage/llava/http_llama3_llava_test.py

@@ -22,11 +22,7 @@ import aiohttp
 import requests
 from llava.conversation import (
-    default_conversation,
-    conv_templates,
-    SeparatorStyle,
     conv_llava_llama_3,
-    conv_qwen,
 )
@@ -43,7 +39,8 @@ async def test_concurrent(args):
     prompt = "<image>\nPlease generate caption towards this image."

     conv_template = copy.deepcopy(conv_llava_llama_3)
-    conv_template.append_message(role="user", message=prompt)
+    conv_template.append_message(role=conv_template.roles[0], message=prompt)
+    conv_template.append_message(role=conv_template.roles[1], message=None)
     prompt_with_template = conv_template.get_prompt()
     response = []
     for i in range(1):
@@ -74,7 +71,8 @@ def test_streaming(args):
     url = f"{args.host}:{args.port}"
     prompt = "<image>\nPlease generate caption towards this image."
     conv_template = copy.deepcopy(conv_llava_llama_3)
-    conv_template.append_message(role="user", message=prompt)
+    conv_template.append_message(role=conv_template.roles[0], message=prompt)
+    conv_template.append_message(role=conv_template.roles[1], message=None)
     prompt_with_template = conv_template.get_prompt()
     pload = {
         "text": prompt_with_template,
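For reference, a minimal sketch of the prompt-construction pattern the updated examples now use. The build_prompt helper below is illustrative only (not part of the commit) and assumes a LLaVA Conversation template such as conv_llava_llama_3, whose roles tuple holds the user and assistant role names:

import copy

from llava.conversation import conv_llava_llama_3


def build_prompt(question: str) -> str:
    # Illustrative helper mirroring the diff above: roles[0] is the user role,
    # roles[1] is the assistant role; passing message=None leaves the assistant
    # turn open so the server generates the caption.
    conv = copy.deepcopy(conv_llava_llama_3)
    conv.append_message(role=conv.roles[0], message=question)
    conv.append_message(role=conv.roles[1], message=None)
    return conv.get_prompt()


# Example: build_prompt("<image>\nPlease generate caption towards this image.")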
examples/usage/llava/http_qwen_llava_test.py

@@ -22,11 +22,7 @@ import aiohttp
 import requests
 from llava.conversation import (
-    default_conversation,
-    conv_templates,
-    SeparatorStyle,
-    conv_llava_llama_3,
-    conv_qwen,
+    conv_qwen
 )
@@ -43,7 +39,8 @@ async def test_concurrent(args):
     prompt = "<image>\nPlease generate caption towards this image."

     conv_template = copy.deepcopy(conv_qwen)
-    conv_template.append_message(role="user", message=prompt)
+    conv_template.append_message(role=conv_template.roles[0], message=prompt)
+    conv_template.append_message(role=conv_template.roles[1], message=None)
     prompt_with_template = conv_template.get_prompt()
     response = []
     for i in range(1):
@@ -74,7 +71,8 @@ def test_streaming(args):
     url = f"{args.host}:{args.port}"
     prompt = "<image>\nPlease generate caption towards this image."
     conv_template = copy.deepcopy(conv_qwen)
-    conv_template.append_message(role="user", message=prompt)
+    conv_template.append_message(role=conv_template.roles[0], message=prompt)
+    conv_template.append_message(role=conv_template.roles[1], message=None)
     prompt_with_template = conv_template.get_prompt()
     pload = {
         "text": prompt_with_template,
@@ -113,5 +111,5 @@ if __name__ == "__main__":
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
     args = parser.parse_args()
-    # asyncio.run(test_concurrent(args))
+    asyncio.run(test_concurrent(args))
     test_streaming(args)
python/sglang/srt/managers/router/model_runner.py

@@ -421,7 +421,12 @@ def import_model_classes():
         if not ispkg:
             module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
-                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
+                entry = module.EntryClass
+                if isinstance(entry, list):
+                    # To support multiple model classes in one module
+                    for cls in entry:
+                        model_arch_name_to_cls[cls.__name__] = cls
+                else:
+                    model_arch_name_to_cls[entry.__name__] = entry
     return model_arch_name_to_cls
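With this change a model module can export either a single class or a list of classes through EntryClass. A standalone sketch of how the registry handles both shapes (DummyModelA and DummyModelB are hypothetical placeholders, not sglang models):

# In a model file, either form is now accepted:
#   EntryClass = LlavaLlamaForCausalLM
#   EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]


class DummyModelA: ...


class DummyModelB: ...


EntryClass = [DummyModelA, DummyModelB]

model_arch_name_to_cls = {}
entry = EntryClass
if isinstance(entry, list):
    # Same branching as the hunk above: register every class in the list.
    for cls in entry:
        model_arch_name_to_cls[cls.__name__] = cls
else:
    model_arch_name_to_cls[entry.__name__] = entry

assert model_arch_name_to_cls == {"DummyModelA": DummyModelA, "DummyModelB": DummyModelB}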
python/sglang/srt/models/llava.py

@@ -5,7 +5,7 @@ from typing import List, Iterable, Optional, Tuple
 import numpy as np
 import torch
 from torch import nn
-from transformers import CLIPVisionModel, LlavaConfig
+from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -18,6 +18,8 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.mistral import MistralForCausalLM


 class LlavaLlamaForCausalLM(nn.Module):
@@ -287,8 +289,101 @@ class LlavaLlamaForCausalLM(nn.Module):
         return self.image_size // self.patch_size


-first_call = True
+class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config=quant_config)
+        self.config = config
+        self.vision_tower = None
+        if getattr(self.config, "vision_config", None) is None:
+            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = Qwen2Config(self.config._name_or_path)
+
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+
+        if getattr(self.config, "projector_hidden_act", None) is None:
+            self.config.projector_hidden_act = "gelu"
+        if getattr(self.config, "image_token_index", None) is None:
+            self.config.image_token_index = 151646
+
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+        new_image_feature_len = self.image_feature_len
+        # now only support spatial_unpad + anyres
+        if self.mm_patch_merge_type.startswith("spatial"):
+            height = width = self.num_patches_per_side
+            if pt_shape[0] > 1:
+                if self.image_aspect_ratio == "anyres":
+                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                        image_size,
+                        self.image_grid_pinpoints,
+                        self.vision_tower.config.image_size,
+                    )
+                if "unpad" in self.mm_patch_merge_type:
+                    h = num_patch_height * height
+                    w = num_patch_width * width
+                    new_h, new_w = unpad_image_shape(h, w, image_size)
+                    new_image_feature_len += new_h * (new_w + 1)
+
+        pad_ids = pad_value * (
+            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        )
+        offset = input_ids.index(self.config.image_token_index)
+        # old_len + pad_len - 1, because we need to remove image_token_id
+        new_input_ids = (
+            input_ids[:offset]
+            + pad_ids[:new_image_feature_len]
+            + input_ids[offset + 1 :]
+        )
+        return new_input_ids, offset
+
+
+class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config=quant_config)
+        self.config = config
+        self.vision_tower = None
+        if getattr(self.config, "vision_config", None) is None:
+            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = MistralConfig(self.config._name_or_path)
+
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+
+        if getattr(self.config, "projector_hidden_act", None) is None:
+            self.config.projector_hidden_act = "gelu"
+        if getattr(self.config, "image_token_index", None) is None:
+            self.config.image_token_index = 32000
+
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = MistralForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
+first_call = True


 def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
     batch_size = pixel_values.shape[0]
@@ -319,4 +414,8 @@ def monkey_path_clip_vision_embed_forward():
     )


-EntryClass = LlavaLlamaForCausalLM
+EntryClass = [
+    LlavaLlamaForCausalLM,
+    LlavaQwenForCausalLM,
+    LlavaMistralForCausalLM,
+]
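To make the padding arithmetic in the newly added pad_input_ids concrete, here is a self-contained sketch of its non-anyres path using plain Python lists and hypothetical token ids (the real method additionally enlarges new_image_feature_len for spatial/anyres inputs before padding):

# Hypothetical values: Qwen's image_token_index (151646), a 4-id pad pattern,
# and an image that expands to 6 feature positions.
image_token_index = 151646
pad_value = [11, 22, 33, 44]
new_image_feature_len = 6

input_ids = [1, 2, 151646, 3, 4]

# Repeat the pad pattern until it covers at least new_image_feature_len positions.
pad_ids = pad_value * ((new_image_feature_len + len(pad_value)) // len(pad_value))

# Replace the single <image> token with the padded placeholder span
# (old_len + pad_len - 1, because the image token itself is removed).
offset = input_ids.index(image_token_index)
new_input_ids = input_ids[:offset] + pad_ids[:new_image_feature_len] + input_ids[offset + 1 :]

assert new_input_ids == [1, 2, 11, 22, 33, 44, 11, 22, 3, 4]
assert len(new_input_ids) == len(input_ids) - 1 + new_image_feature_len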
python/sglang/srt/models/llava_mistral.py (deleted, 100644 → 0)

"""Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Iterable, Optional, Tuple

import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, MistralConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
    unpad_image_shape,
)
from sglang.srt.models.mistral import MistralForCausalLM


class LlavaMistralForCausalLM(nn.Module):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.vision_tower = None
        if getattr(self.config, "vision_config", None) is None:
            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = MistralConfig(self.config._name_or_path)

        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"
        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 32000

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = MistralForCausalLM(config, quant_config=quant_config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )

    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
        new_image_feature_len = self.image_feature_len
        # now only support spatial_unpad + anyres
        if self.mm_patch_merge_type.startswith("spatial"):
            height = width = self.num_patches_per_side
            if pt_shape[0] > 1:
                if self.image_aspect_ratio == "anyres":
                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                        image_size,
                        self.image_grid_pinpoints,
                        self.vision_tower.config.image_size,
                    )
                if "unpad" in self.mm_patch_merge_type:
                    h = num_patch_height * height
                    w = num_patch_width * width
                    new_h, new_w = unpad_image_shape(h, w, image_size)
                    new_image_feature_len += new_h * (new_w + 1)

        pad_ids = pad_value * (
            (new_image_feature_len + len(pad_value)) // len(pad_value)
        )
        offset = input_ids.index(self.config.image_token_index)
        # old_len + pad_len - 1, because we need to remove image_token_id
        new_input_ids = (
            input_ids[:offset]
            + pad_ids[:new_image_feature_len]
            + input_ids[offset + 1 :]
        )
        return new_input_ids, offset

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
            )
        image_features = self.multi_modal_projector(selected_image_feature)
        return image_features

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.array]]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        image_offsets: Optional[List[int]] = None,
    ) -> torch.Tensor:
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            bs = input_metadata.batch_size

            # Embed text input
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Embed vision input
            need_vision = (
                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
                .cpu()
                .numpy()
            )
            # FIXME: We need to substract the length of the system prompt
            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
            need_vision = need_vision & has_pixel

            if need_vision.any():
                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
                    np.concatenate(pixel_values, axis=0)
                    # ndim=4
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                if self.mm_patch_merge_type.startswith("spatial"):
                    new_image_features = []
                    for image_idx, image_feature in enumerate(image_features):
                        if image_feature.shape[0] > 1:
                            base_image_feature = image_feature[0]
                            image_feature = image_feature[1:]
                            height = width = self.num_patches_per_side
                            assert height * width == base_image_feature.shape[0]
                            if self.image_aspect_ratio == "anyres":
                                (
                                    num_patch_width,
                                    num_patch_height,
                                ) = get_anyres_image_grid_shape(
                                    image_sizes[image_idx],
                                    self.image_grid_pinpoints,
                                    self.vision_tower.config.image_size,
                                )
                                image_feature = image_feature.view(
                                    num_patch_height, num_patch_width, height, width, -1
                                )
                            else:
                                raise NotImplementedError()
                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = image_feature.permute(
                                    4, 0, 2, 1, 3
                                ).contiguous()
                                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                                image_feature = unpad_image(
                                    image_feature, image_sizes[image_idx]
                                )
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[
                                            :, None, None
                                        ].expand(*image_feature.shape[:-1], 1),
                                    ),
                                    dim=-1,
                                )
                                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                            else:
                                image_feature = image_feature.permute(
                                    0, 2, 1, 3, 4
                                ).contiguous()
                                image_feature = image_feature.flatten(0, 3)
                            image_feature = torch.cat(
                                (base_image_feature, image_feature), dim=0
                            )
                        else:
                            image_feature = image_feature[0]
                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[None],
                                    ),
                                    dim=0,
                                )
                        new_image_features.append(image_feature)
                    image_features = new_image_features

                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
                    dim = input_embeds.shape[1]
                    assert (
                        pad_dim == dim
                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                    # Fill in the placeholder for the image
                    try:
                        input_embeds[
                            start_idx
                            + image_offsets[i] : start_idx
                            + image_offsets[i]
                            + pad_len
                        ] = image_features[pt]
                    except RuntimeError as e:
                        print(f"RuntimeError in llava image encoding: {e}")
                        print(input_embeds.shape)
                        print(start_idx, image_offsets[i])
                    pt += 1

            return self.language_model(
                input_ids, positions, input_metadata, input_embeds=input_embeds
            )
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.language_model(input_ids, positions, input_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # load clip vision model by cfg['mm_vision_tower']:
        #   huggingface_name or path_of_clip_relative_to_llava_model_dir
        vision_path = self.config.mm_vision_tower
        self.vision_tower = CLIPVisionModel.from_pretrained(
            vision_path, torch_dtype=torch.float16
        ).cuda()
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_tower.vision_tower": "vision_tower",
            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
        }
        params_dict = dict(self.named_parameters())
        weights = list(weights)
        for name, loaded_weight in weights:
            # FIXME: why projector weights read two times?
            if "projector" in name or "vision_tower" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        # load language model
        self.language_model.load_weights(weights)

        monkey_path_clip_vision_embed_forward()

    @property
    def num_patches_per_side(self):
        return self.image_size // self.patch_size


first_call = True


def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    batch_size = pixel_values.shape[0]

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    global first_call
    if first_call:
        self.patch_embedding.cpu().float()
        first_call = False
    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
    patch_embeds = self.patch_embedding(pixel_values).cuda().half()

    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings


def monkey_path_clip_vision_embed_forward():
    import transformers

    setattr(
        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
        "forward",
        clip_vision_embed_forward,
    )


EntryClass = LlavaMistralForCausalLM
python/sglang/srt/models/llava_qwen.py (deleted, 100644 → 0)

"""Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Iterable, Optional, Tuple

import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, Qwen2Config
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
    unpad_image_shape,
)
from sglang.srt.models.qwen2 import Qwen2ForCausalLM


class LlavaQwenForCausalLM(nn.Module):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.vision_tower = None
        if getattr(self.config, "vision_config", None) is None:
            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = Qwen2Config(self.config._name_or_path)

        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"
        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 151646

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )

    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
        new_image_feature_len = self.image_feature_len
        # now only support spatial_unpad + anyres
        if self.mm_patch_merge_type.startswith("spatial"):
            height = width = self.num_patches_per_side
            if pt_shape[0] > 1:
                if self.image_aspect_ratio == "anyres":
                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                        image_size,
                        self.image_grid_pinpoints,
                        self.vision_tower.config.image_size,
                    )
                if "unpad" in self.mm_patch_merge_type:
                    h = num_patch_height * height
                    w = num_patch_width * width
                    new_h, new_w = unpad_image_shape(h, w, image_size)
                    new_image_feature_len += new_h * (new_w + 1)

        pad_ids = pad_value * (
            (new_image_feature_len + len(pad_value)) // len(pad_value)
        )
        offset = input_ids.index(self.config.image_token_index)
        # old_len + pad_len - 1, because we need to remove image_token_id
        new_input_ids = (
            input_ids[:offset]
            + pad_ids[:new_image_feature_len]
            + input_ids[offset + 1 :]
        )
        return new_input_ids, offset

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
            )
        image_features = self.multi_modal_projector(selected_image_feature)
        return image_features

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.array]]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        image_offsets: Optional[List[int]] = None,
    ) -> torch.Tensor:
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            bs = input_metadata.batch_size

            # Embed text input
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Embed vision input
            need_vision = (
                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
                .cpu()
                .numpy()
            )
            # FIXME: We need to substract the length of the system prompt
            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
            need_vision = need_vision & has_pixel

            if need_vision.any():
                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
                    np.concatenate(pixel_values, axis=0)
                    # ndim=4
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                if self.mm_patch_merge_type.startswith("spatial"):
                    new_image_features = []
                    for image_idx, image_feature in enumerate(image_features):
                        if image_feature.shape[0] > 1:
                            base_image_feature = image_feature[0]
                            image_feature = image_feature[1:]
                            height = width = self.num_patches_per_side
                            assert height * width == base_image_feature.shape[0]
                            if self.image_aspect_ratio == "anyres":
                                (
                                    num_patch_width,
                                    num_patch_height,
                                ) = get_anyres_image_grid_shape(
                                    image_sizes[image_idx],
                                    self.image_grid_pinpoints,
                                    self.vision_tower.config.image_size,
                                )
                                image_feature = image_feature.view(
                                    num_patch_height, num_patch_width, height, width, -1
                                )
                            else:
                                raise NotImplementedError()
                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = image_feature.permute(
                                    4, 0, 2, 1, 3
                                ).contiguous()
                                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                                image_feature = unpad_image(
                                    image_feature, image_sizes[image_idx]
                                )
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[
                                            :, None, None
                                        ].expand(*image_feature.shape[:-1], 1),
                                    ),
                                    dim=-1,
                                )
                                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                            else:
                                image_feature = image_feature.permute(
                                    0, 2, 1, 3, 4
                                ).contiguous()
                                image_feature = image_feature.flatten(0, 3)
                            image_feature = torch.cat(
                                (base_image_feature, image_feature), dim=0
                            )
                        else:
                            image_feature = image_feature[0]
                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[None],
                                    ),
                                    dim=0,
                                )
                        new_image_features.append(image_feature)
                    image_features = new_image_features

                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
                    dim = input_embeds.shape[1]
                    assert (
                        pad_dim == dim
                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                    # Fill in the placeholder for the image
                    try:
                        input_embeds[
                            start_idx
                            + image_offsets[i] : start_idx
                            + image_offsets[i]
                            + pad_len
                        ] = image_features[pt]
                    except RuntimeError as e:
                        print(f"RuntimeError in llava image encoding: {e}")
                        print(input_embeds.shape)
                        print(start_idx, image_offsets[i])
                    pt += 1

            return self.language_model(
                input_ids, positions, input_metadata, input_embeds=input_embeds
            )
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.language_model(input_ids, positions, input_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # load clip vision model by cfg['mm_vision_tower']:
        #   huggingface_name or path_of_clip_relative_to_llava_model_dir
        vision_path = self.config.mm_vision_tower
        self.vision_tower = CLIPVisionModel.from_pretrained(
            vision_path, torch_dtype=torch.float16
        ).cuda()
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_tower.vision_tower": "vision_tower",
            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
        }
        params_dict = dict(self.named_parameters())
        weights = list(weights)
        for name, loaded_weight in weights:
            # FIXME: why projector weights read two times?
            if "projector" in name or "vision_tower" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        # load language model
        self.language_model.load_weights(weights)

        monkey_path_clip_vision_embed_forward()

    @property
    def num_patches_per_side(self):
        return self.image_size // self.patch_size


first_call = True


def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    batch_size = pixel_values.shape[0]

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    global first_call
    if first_call:
        self.patch_embedding.cpu().float()
        first_call = False
    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
    patch_embeds = self.patch_embedding(pixel_values).cuda().half()

    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings


def monkey_path_clip_vision_embed_forward():
    import transformers

    setattr(
        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
        "forward",
        clip_vision_embed_forward,
    )


EntryClass = LlavaQwenForCausalLM