change / sglang · Commits · c6576e82

Unverified commit c6576e82, authored Jan 24, 2024 by shiyi.c_98, committed by GitHub on Jan 24, 2024.

Llava-hd Support (#92)

Co-authored-by: Haotian Liu <liuhaotian.cn@gmail.com>

Parent: 99258181

Showing 10 changed files with 430 additions and 39 deletions (+430, -39).
Changed files:

  examples/quick_start/srt_example_llava.py           +4    -2
  python/pyproject.toml                                +1    -1
  python/sglang/srt/managers/io_struct.py              +1    -0
  python/sglang/srt/managers/router/infer_batch.py     +3    -0
  python/sglang/srt/managers/router/model_rpc.py       +2    -1
  python/sglang/srt/managers/router/model_runner.py    +3    -0
  python/sglang/srt/managers/tokenizer_manager.py      +26   -9
  python/sglang/srt/mm_utils.py                        +251  -0
  python/sglang/srt/models/llava.py                    +139  -25
  python/sglang/srt/server.py                          +0    -1
examples/quick_start/srt_example_llava.py

@@ -7,8 +7,10 @@ def image_qa(s, image_path, question):
     s += sgl.assistant(sgl.gen("answer"))
 
-runtime = sgl.Runtime(model_path="liuhaotian/llava-v1.5-7b",
-                      tokenizer_path="llava-hf/llava-1.5-7b-hf")
+# runtime = sgl.Runtime(model_path="liuhaotian/llava-v1.5-7b",
+#                       tokenizer_path="llava-hf/llava-1.5-7b-hf")
+runtime = sgl.Runtime(model_path="llava-internal/llava-v1.6-7b-hd-224px_3x2-preview-20230103",
+                      tokenizer_path="llava-internal/llava-v1.6-7b-hd-224px_3x2-preview-20230103-tokenizer")
 
 sgl.set_default_backend(runtime)
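Running the quick-start script is unchanged after this switch; only the backend model differs. A small sketch (the image path here is hypothetical, and image_qa is the function defined at the top of this file):

state = image_qa.run(
    image_path="images/cat.jpeg",  # hypothetical local image
    question="What is shown in this picture?",
)
print(state["answer"])

runtime.shutdown()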
python/pyproject.toml

@@ -18,7 +18,7 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
+srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
        "zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
        "pydantic", "diskcache", "cloudpickle"]
 openai = ["openai>=1.0", "numpy"]
python/sglang/srt/managers/io_struct.py

@@ -62,6 +62,7 @@ class TokenizedGenerateReqInput:
     input_ids: List[int]
     pixel_values: List[float]
     image_hash: int
+    image_size: List[int]
     sampling_params: SamplingParams
     return_logprob: bool
     logprob_start_len: int
python/sglang/srt/managers/router/infer_batch.py

@@ -26,6 +26,7 @@ class Req:
         self.input_ids = []
         self.output_ids = []
         self.pixel_values = None
+        self.image_size = None
         self.image_offset = 0
         self.sampling_params = None
         self.return_logprob = False

@@ -104,6 +105,7 @@ class Batch:
     # for multimodal
     pixel_values: List[torch.Tensor] = None
+    image_sizes: List[List[int]] = None
     image_offsets: List[int] = None

     # other arguments for control

@@ -195,6 +197,7 @@ class Batch:
             flatten_input_ids, dtype=torch.int32, device=device
         )

         self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
             r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
         ]
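A note on the offset arithmetic above: image_offset is recorded against a request's full padded input_ids, while the batch only prefills the portion that extends past the cached prefix, so the per-request prefix length is subtracted. A toy illustration with made-up numbers:

# If the image pad region starts at position 40 of the padded input_ids and the
# first 32 tokens of this request are already cached, the image begins at
# offset 40 - 32 = 8 within the newly extended tokens this batch will prefill.
image_offset, prefix_len = 40, 32
assert image_offset - prefix_len == 8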
python/sglang/srt/managers/router/model_rpc.py

@@ -203,6 +203,7 @@ class ModelRpcServer(rpyc.Service):
         req = Req(recv_req.rid)
         req.input_ids = recv_req.input_ids
         req.pixel_values = recv_req.pixel_values
+        req.image_size = recv_req.image_size
         if req.pixel_values is not None:
             pad_value = [
                 (recv_req.image_hash) % self.model_config.vocab_size,

@@ -211,7 +212,7 @@ class ModelRpcServer(rpyc.Service):
                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
             ]
             req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
-                req.input_ids, pad_value
+                req.input_ids, pad_value, req.pixel_values.shape, req.image_size
             )
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
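The pad tokens that stand in for the image are derived from the image hash (the hunk above shows only the first and last entries of the pad_value list; the middle entries are collapsed). A plausible reading is that the same image must always map to the same padded input_ids while different images should essentially never collide, which matters for prefix-cache matching. A toy sketch of that property, mirroring only the two entries visible above:

vocab_size = 32000  # hypothetical vocab size

def pad_tokens(image_hash):
    # Mirrors the first and last entries shown in the hunk; the real list has more.
    return [image_hash % vocab_size, (image_hash >> 64) % vocab_size]

assert pad_tokens(12345) == pad_tokens(12345)  # same image -> same pad tokens
assert pad_tokens(12345) != pad_tokens(67890)  # different images differ (w.h.p.)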
python/sglang/srt/managers/router/model_runner.py

@@ -409,6 +409,7 @@ class ModelRunner:
         self,
         input_ids,
         pixel_values,
+        image_sizes,
         image_offsets,
         req_pool_indices,
         seq_lens,

@@ -433,6 +434,7 @@ class ModelRunner:
             input_metadata.positions,
             input_metadata,
             pixel_values,
+            image_sizes,
             image_offsets,
         )

@@ -441,6 +443,7 @@ class ModelRunner:
         kwargs = {
             "input_ids": batch.input_ids,
             "pixel_values": batch.pixel_values,
+            "image_sizes": batch.image_sizes,
             "image_offsets": batch.image_offsets,
             "req_pool_indices": batch.req_pool_indices,
             "seq_lens": batch.seq_lens,
python/sglang/srt/managers/tokenizer_manager.py

@@ -20,6 +20,7 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image

@@ -48,14 +49,25 @@ def init_global_processor(server_args: ServerArgs):
     )


-def get_pixel_values(image_data, processor=None):
+def get_pixel_values(image_data, model_cfg, processor=None):
+    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
     try:
         processor = processor or global_processor
         image = load_image(image_data)
         image_hash = hash(image_data)
-        pixel_values = processor.image_processor(image)["pixel_values"][0]
+        if image_aspect_ratio == "pad":
+            image = expand2square(
+                image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
+            )
+            pixel_values = processor.image_processor(image)["pixel_values"][0]
+        elif image_aspect_ratio == "anyres":
+            pixel_values = process_anyres_image(
+                image, processor.image_processor, model_cfg.image_grid_pinpoints
+            )
+        else:
+            pixel_values = processor.image_processor(image)["pixel_values"][0]
         pixel_values = pixel_values.astype(np.float16)
-        return pixel_values, image_hash
+        return pixel_values, image_hash, image.size
     except Exception:
         print("Exception in TokenizerManager:\n" + get_exception_traceback())

@@ -77,6 +89,7 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path, trust_remote_code=server_args.trust_remote_code
         )
         self.context_len = get_context_length(self.hf_config)

         if is_multimodal_model(self.model_path):

@@ -104,10 +117,10 @@ class TokenizerManager:
         if self.executor is not None:
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(
-                self.executor, get_pixel_values, image_data
+                self.executor, get_pixel_values, image_data, self.hf_config
             )
         else:
-            return get_pixel_values(image_data, self.processor)
+            return get_pixel_values(image_data, self.hf_config, self.processor)

     async def generate_request(self, obj: GenerateReqInput):
         if self.to_create_loop:

@@ -123,14 +136,17 @@ class TokenizerManager:
             sampling_params.normalize(self.tokenizer)
             sampling_params.verify()
             if obj.image_data is None:
-                pixel_values, image_hash = None, None
+                pixel_values, image_hash, image_size = None, None, None
             else:
-                pixel_values, image_hash = await self.get_pixel_values(obj.image_data)
+                pixel_values, image_hash, image_size = await self.get_pixel_values(
+                    obj.image_data
+                )
             tokenized_obj = TokenizedGenerateReqInput(
                 rid=rid,
                 input_ids=input_ids,
                 pixel_values=pixel_values,
                 image_hash=image_hash,
+                image_size=image_size,
                 sampling_params=sampling_params,
                 return_logprob=obj.return_logprob,
                 logprob_start_len=obj.logprob_start_len,

@@ -162,9 +178,9 @@ class TokenizerManager:
                 sampling_params.normalize(self.tokenizer)
                 sampling_params.verify()
                 if obj.image_data[i] is None:
-                    pixel_values, image_hash = None, None
+                    pixel_values, image_hash, image_size = None, None, None
                 else:
-                    pixel_values, image_hash = await self.get_pixel_values(
+                    pixel_values, image_hash, image_size = await self.get_pixel_values(
                         obj.image_data[i]
                     )
                 tokenized_obj = TokenizedGenerateReqInput(

@@ -172,6 +188,7 @@ class TokenizerManager:
                     input_ids=input_ids,
                     pixel_values=pixel_values,
                     image_hash=image_hash,
+                    image_size=image_size,
                     sampling_params=sampling_params,
                     return_logprob=obj.return_logprob[i],
                     logprob_start_len=obj.logprob_start_len[i],
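A sketch of what the anyres branch produces, using a stand-alone CLIP image processor and the grid pinpoints commonly used by 336px LLaVA-1.6 checkpoints (the 1000x800 image is a made-up stand-in; inside the server the processor and pinpoints come from the model's own config instead):

from PIL import Image
from transformers import CLIPImageProcessor

from sglang.srt.mm_utils import process_anyres_image

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]

image = Image.new("RGB", (1000, 800))  # stand-in for a real photo
pixel_values = process_anyres_image(image, processor, grid_pinpoints)
print(pixel_values.shape)  # (5, 3, 336, 336): 1 resized base view + a 2x2 patch grid
print(image.size)          # (1000, 800), now returned by get_pixel_values as well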
python/sglang/srt/mm_utils.py (new file, mode 100644)

# Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
import ast
import base64
import math
from io import BytesIO

import numpy as np
from PIL import Image


def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(
            original_height * scale
        )
        effective_resolution = min(
            downscaled_width * downscaled_height, original_width * original_height
        )
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution
            and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image


def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str): A string representation of a list of possible resolutions.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size


def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str): A string representation of a list of possible resolutions.

    Returns:
        np.array: An np array containing the processed image patches.
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size["height"])

    image_original_resize = image.resize(
        (processor.size["shortest_edge"], processor.size["shortest_edge"])
    )

    image_patches = [image_original_resize] + patches
    image_patches = [
        processor.preprocess(image_patch)["pixel_values"][0]
        for image_patch in image_patches
    ]
    return np.stack(image_patches, axis=0)


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the image (height, width).

    Returns:
        torch.Tensor: The unpadded image tensor.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor


def unpad_image_shape(current_height, current_width, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image
    and returns the new shape.
    """
    original_width, original_height = original_size

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        new_shape = (current_height - 2 * padding, current_width)
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        new_shape = (current_height, current_width - 2 * padding)

    return new_shape


def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == "pad":
        for image in images:
            image = expand2square(
                image, tuple(int(x * 255) for x in image_processor.image_mean)
            )
            image = image_processor.preprocess(image)["pixel_values"][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres":
        for image in images:
            image = process_anyres_image(
                image, image_processor, model_cfg.image_grid_pinpoints
            )
            new_images.append(image)
    else:
        return image_processor(images)["pixel_values"]
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = np.stack(new_images, axis=0)
    return new_images
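To make the resolution-selection logic concrete, a short worked example (the candidate resolutions are the ones commonly used by 336px LLaVA-1.6 configs; the 1000x800 image size is made up):

from sglang.srt.mm_utils import get_anyres_image_grid_shape, select_best_resolution

possible = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]

# A 1000x800 image scales down to 672x537 inside the 672x672 candidate, which keeps
# the largest effective area among the five candidates, so 672x672 is chosen.
assert select_best_resolution((1000, 800), possible) == (672, 672)

# With a 336px vision tower, that resolution is then tiled into a 2x2 patch grid.
assert get_anyres_image_grid_shape((1000, 800), possible, 336) == (2, 2)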
python/sglang/srt/models/llava.py

 """Inference-only LLaVa model compatible with HuggingFace weights."""

-import json
-from typing import List, Optional
+import os
+from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
 import torch
 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.mm_utils import (
+    get_anyres_image_grid_shape,
+    unpad_image,
+    unpad_image_shape,
+)
 from sglang.srt.models.llama2 import LlamaForCausalLM
 from torch import nn
-from transformers import CLIPImageProcessor, CLIPVisionModel, LlavaConfig
+from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.model_executor.layers.linear import LinearMethodBase
 from vllm.model_executor.weight_utils import (

@@ -31,26 +34,64 @@ class LlavaLlamaForCausalLM(nn.Module):
         self.config.text_config.hidden_size = config.hidden_size
         self.multi_modal_projector = LlavaMultiModalProjector(config)
         self.language_model = LlamaForCausalLM(config, linear_method)
+        if "unpad" in getattr(config, "mm_patch_merge_type"):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )

-    def pad_input_ids(self, input_ids, pad_value):
+    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+        new_image_feature_len = self.image_feature_len
+        # now only support spatial_unpad + anyres
+        if self.mm_patch_merge_type.startswith("spatial"):
+            height = width = self.num_patches_per_side
+            if pt_shape[0] > 1:
+                if self.image_aspect_ratio == "anyres":
+                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                        image_size,
+                        self.image_grid_pinpoints,
+                        self.vision_tower.config.image_size,
+                    )
+                if "unpad" in self.mm_patch_merge_type:
+                    h = num_patch_height * height
+                    w = num_patch_width * width
+                    new_h, new_w = unpad_image_shape(h, w, image_size)
+                    new_image_feature_len += new_h * (new_w + 1)
+
         pad_ids = pad_value * (
-            (self.image_feature_len + len(pad_value)) // len(pad_value)
+            (new_image_feature_len + len(pad_value)) // len(pad_value)
         )
         offset = input_ids.index(self.config.image_token_index)
         # old_len + pad_len - 1, because we need to remove image_token_id
         new_input_ids = (
             input_ids[:offset]
-            + pad_ids[: self.image_feature_len]
+            + pad_ids[:new_image_feature_len]
             + input_ids[offset + 1 :]
         )
         return new_input_ids, offset

+    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
+        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
+        if self.vision_feature_select_strategy in ["default", "patch"]:
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif self.vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise ValueError(
+                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+            )
+        image_features = self.multi_modal_projector(selected_image_feature)
+        return image_features
+
     def forward(
         self,
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         pixel_values: Optional[List[Optional[np.array]]] = None,
+        image_sizes: Optional[List[List[int]]] = None,
         image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
         if input_metadata.forward_mode == ForwardMode.EXTEND:

@@ -75,23 +116,86 @@ class LlavaLlamaForCausalLM(nn.Module):
                 device=self.vision_tower.device,
             )
-            image_outputs = self.vision_tower(
-                pixel_values, output_hidden_states=True
-            )
-            # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
-            selected_image_feature = image_outputs.hidden_states[
-                self.vision_feature_layer
-            ]
-            if self.vision_feature_select_strategy in ["default", "patch"]:
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif self.vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-            else:
-                raise ValueError(
-                    f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
-                )
-            image_features = self.multi_modal_projector(selected_image_feature)
+            ########## Encode Image ########
+
+            if pixel_values.ndim == 5:
+                # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
+                concat_images = torch.cat(
+                    [image for image in pixel_values], dim=0
+                )  # ndim=4
+                image_features = self.encode_images(concat_images)
+                split_sizes = [image.shape[0] for image in pixel_values]
+                image_features = torch.split(image_features, split_sizes, dim=0)
+                # hd image_features: BS, num_patch, 576, 4096
+            else:
+                # normal pixel: BS, C=3, H=336, W=336
+                image_features = self.encode_images(pixel_values)
+                # image_features: BS, 576, 4096
+
+            if self.mm_patch_merge_type.startswith("spatial"):
+                new_image_features = []
+                for image_idx, image_feature in enumerate(image_features):
+                    if image_feature.shape[0] > 1:
+                        base_image_feature = image_feature[0]
+                        image_feature = image_feature[1:]
+                        height = width = self.num_patches_per_side
+                        assert height * width == base_image_feature.shape[0]
+                        if self.image_aspect_ratio == "anyres":
+                            (
+                                num_patch_width,
+                                num_patch_height,
+                            ) = get_anyres_image_grid_shape(
+                                image_sizes[image_idx],
+                                self.image_grid_pinpoints,
+                                self.vision_tower.config.image_size,
+                            )
+                            image_feature = image_feature.view(
+                                num_patch_height, num_patch_width, height, width, -1
+                            )
+                        else:
+                            raise NotImplementedError
+                        if "unpad" in self.mm_patch_merge_type:
+                            image_feature = image_feature.permute(
+                                4, 0, 2, 1, 3
+                            ).contiguous()
+                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                            image_feature = unpad_image(
+                                image_feature, image_sizes[image_idx]
+                            )
+                            image_feature = torch.cat(
+                                (
+                                    image_feature,
+                                    self.language_model.model.image_newline[
+                                        :, None, None
+                                    ].expand(*image_feature.shape[:-1], 1),
+                                ),
+                                dim=-1,
+                            )
+                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                        else:
+                            image_feature = image_feature.permute(
+                                0, 2, 1, 3, 4
+                            ).contiguous()
+                            image_feature = image_feature.flatten(0, 3)
+                        image_feature = torch.cat(
+                            (base_image_feature, image_feature), dim=0
+                        )
+                    else:
+                        image_feature = image_feature[0]
+                        if "unpad" in self.mm_patch_merge_type:
+                            image_feature = torch.cat(
+                                (
+                                    image_feature,
+                                    self.language_model.model.image_newline[None],
+                                ),
+                                dim=0,
+                            )
+                    new_image_features.append(image_feature)
+                image_features = new_image_features

             extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
             pt = 0

@@ -100,7 +204,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                     continue
                 start_idx = extend_start_loc_cpu[i]
-                pad_len, pad_dim = image_features[pt].shape
+                pad_len, pad_dim = image_features[pt].shape  # 576, 4096
                 dim = input_embeds.shape[1]
                 assert (
                     pad_dim == dim

@@ -146,6 +250,11 @@ class LlavaLlamaForCausalLM(nn.Module):
         self.vision_feature_select_strategy = self.config.mm_vision_select_feature
         self.image_size = self.vision_tower.config.image_size
         self.patch_size = self.vision_tower.config.patch_size
+
+        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
+
         self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
         if self.vision_feature_select_strategy == "patch":
             pass

@@ -159,13 +268,14 @@ class LlavaLlamaForCausalLM(nn.Module):
         projector_weights = {
             "model.mm_projector.0": "multi_modal_projector.linear_1",
             "model.mm_projector.2": "multi_modal_projector.linear_2",
+            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
         }
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
             model_name_or_path, cache_dir, load_format, revision
         ):
             # FIXME: why projector weights read two times?
-            if "projector" in name:
+            if "projector" in name or "vision_tower" in name:
                 for weight_name, param_name in projector_weights.items():
                     if weight_name in name:
                         name = name.replace(weight_name, param_name)

@@ -180,6 +290,10 @@ class LlavaLlamaForCausalLM(nn.Module):
         monkey_path_clip_vision_embed_forward()

+    @property
+    def num_patches_per_side(self):
+        return self.image_size // self.patch_size
+

 first_call = True
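The padded length computed by the new pad_input_ids can be checked by hand. For the same hypothetical 1000x800 image, a 336px vision tower with 14px patches (so 24 patches per side and a 576-token base view), and the spatial_unpad + anyres configuration:

from sglang.srt.mm_utils import unpad_image_shape

# The anyres grid for (1000, 800) is 2x2 patches, so the tiled feature grid starts at
# h = w = 2 * 24 = 48; unpad_image_shape then removes the vertical padding:
# scale = 48 / 1000, new_height = int(800 * 0.048) = 38, padding = (48 - 38) // 2 = 5.
assert unpad_image_shape(48, 48, (1000, 800)) == (38, 48)

# Each unpadded row also receives one image_newline token, hence new_w + 1.
new_image_feature_len = 576 + 38 * (48 + 1)
assert new_image_feature_len == 2438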
python/sglang/srt/server.py

@@ -469,7 +469,6 @@ class Runtime:
         prompt: str,
         sampling_params,
     ) -> None:
         json_data = {
             "text": prompt,
             "sampling_params": sampling_params,