Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0147f940
Unverified
Commit
0147f940
authored
Jan 25, 2024
by
shiyi.c_98
Committed by
GitHub
Jan 25, 2024
Browse files
fix batch error for llava-hd (#98)
parent
23950056
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
8 deletions
+12
-8
python/sglang/srt/models/llava.py
python/sglang/srt/models/llava.py
+12
-8
No files found.
python/sglang/srt/models/llava.py
View file @
0147f940
...
@@ -112,24 +112,28 @@ class LlavaLlamaForCausalLM(nn.Module):
...
@@ -112,24 +112,28 @@ class LlavaLlamaForCausalLM(nn.Module):
need_vision
=
need_vision
&
has_pixel
need_vision
=
need_vision
&
has_pixel
if
need_vision
.
any
():
if
need_vision
.
any
():
pixel_values
=
torch
.
tensor
(
pixel_values
=
[
pixel_values
[
i
]
for
i
in
range
(
bs
)
if
need_vision
[
i
]]
np
.
array
([
pixel_values
[
i
]
for
i
in
range
(
bs
)
if
need_vision
[
i
]]),
image_sizes
=
[
image_sizes
[
i
]
for
i
in
range
(
bs
)
if
need_vision
[
i
]]
device
=
self
.
vision_tower
.
device
,
)
########## Encode Image ########
########## Encode Image ########
if
pixel_values
.
ndim
==
5
:
if
pixel_values
[
0
]
.
ndim
==
4
:
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
concat_images
=
torch
.
cat
(
np
.
concatenate
(
pixel_values
,
axis
=
0
)
[
image
for
image
in
pixel_values
],
dim
=
0
# ndim=4
)
# ndim=4
concat_images
=
torch
.
tensor
(
np
.
concatenate
(
pixel_values
,
axis
=
0
),
device
=
self
.
vision_tower
.
device
,
)
image_features
=
self
.
encode_images
(
concat_images
)
image_features
=
self
.
encode_images
(
concat_images
)
split_sizes
=
[
image
.
shape
[
0
]
for
image
in
pixel_values
]
split_sizes
=
[
image
.
shape
[
0
]
for
image
in
pixel_values
]
image_features
=
torch
.
split
(
image_features
,
split_sizes
,
dim
=
0
)
image_features
=
torch
.
split
(
image_features
,
split_sizes
,
dim
=
0
)
# hd image_features: BS, num_patch, 576, 4096
# hd image_features: BS, num_patch, 576, 4096
else
:
else
:
# normal pixel: BS, C=3, H=336, W=336
# normal pixel: BS, C=3, H=336, W=336
pixel_values
=
torch
.
tensor
(
np
.
array
(
pixel_values
),
device
=
self
.
vision_tower
.
device
)
image_features
=
self
.
encode_images
(
pixel_values
)
image_features
=
self
.
encode_images
(
pixel_values
)
# image_features: BS, 576, 4096
# image_features: BS, 576, 4096
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment