Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
Wan2.1_pytorch
Commits
4f71a2b0
Commit
4f71a2b0
authored
Feb 28, 2025
by
mashun1
Browse files
wan2.1
parents
Pipeline
#2434
failed with stages
in 0 seconds
Changes
82
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
481 additions
and
0 deletions
+481
-0
wan/utils/qwen_vl_utils.py
wan/utils/qwen_vl_utils.py
+363
-0
wan/utils/utils.py
wan/utils/utils.py
+118
-0
No files found.
wan/utils/qwen_vl_utils.py
0 → 100644
View file @
4f71a2b0
# Copied from https://github.com/kq-chen/qwen-vl-utils
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from
__future__
import
annotations
import
base64
import
logging
import
math
import
os
import
sys
import
time
import
warnings
from
functools
import
lru_cache
from
io
import
BytesIO
import
requests
import
torch
import
torchvision
from
packaging
import
version
from
PIL
import
Image
from
torchvision
import
io
,
transforms
from
torchvision.transforms
import
InterpolationMode
# Module-level logger; handler and level configuration is left to the application.
logger = logging.getLogger(__name__)

# --- Image sizing -----------------------------------------------------------
# Default `factor` for smart_resize: output sides are multiples of this value.
IMAGE_FACTOR = 28
# Default pixel-count bounds for a resized image, written as N * 28 * 28.
MIN_PIXELS = 4 * 28**2
MAX_PIXELS = 16384 * 28**2
# Largest long-side / short-side ratio smart_resize accepts.
MAX_RATIO = 200

# --- Video sizing -----------------------------------------------------------
# Default per-frame bounds and whole-clip pixel budget used by fetch_video.
VIDEO_MIN_PIXELS = 128 * 28**2
VIDEO_MAX_PIXELS = 768 * 28**2
VIDEO_TOTAL_PIXELS = 24576 * 28**2

# --- Frame sampling ---------------------------------------------------------
# Sampled frame counts are multiples of FRAME_FACTOR.
FRAME_FACTOR = 2
# Default sampling rate and the default bounds on the sampled frame count.
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of `factor` nearest to `number`.

    The quotient is rounded with the built-in `round`, so exact halves go to
    the even quotient (e.g. 10 with factor 4 gives 8, 14 gives 16).
    """
    multiples = round(number / factor)
    return multiples * factor
def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest multiple of `factor` that is >= `number`."""
    multiples = math.ceil(number / factor)
    return multiples * factor
def
floor_by_factor
(
number
:
int
,
factor
:
int
)
->
int
:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return
math
.
floor
(
number
/
factor
)
*
factor
def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
    """Choose a target (height, width) for resizing an image.

    The result satisfies:

    1. Both sides are positive multiples of `factor`.
    2. The pixel count lies within [`min_pixels`, `max_pixels`] whenever that
       is possible without shrinking a side below `factor`.
    3. The aspect ratio stays as close to the input's as the constraints allow.

    Args:
        height: Source height in pixels.
        width: Source width in pixels.
        factor: Both returned sides are multiples of this value.
        min_pixels: Lower bound on `height * width` of the result.
        max_pixels: Upper bound on `height * width` of the result.

    Returns:
        The `(height, width)` pair to resize to.

    Raises:
        ValueError: If the long side is more than `MAX_RATIO` times the short side.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )

    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))

    if h_bar * w_bar > max_pixels:
        # Scale both sides down by the same ratio. Rounding down can reach 0
        # for the short side of an elongated image, so keep each side at least
        # one `factor` to avoid returning an empty target size.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        # Scale both sides up by the same ratio; rounding up keeps the result
        # at or above the lower bound.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
    """Load the image described by `ele`, convert it to RGB and resize it.

    Args:
        ele: Image description. The source is read from `ele["image"]`, or from
            `ele["image_url"]` when `"image"` is absent. It may be a PIL image,
            an `http://` / `https://` URL, a `file://` path, a
            `data:image...base64,` URI, or a local path. Optional keys:
            `resized_height` and `resized_width` (both required to take
            effect), otherwise `min_pixels` / `max_pixels`.
        size_factor: Passed to `smart_resize` as `factor`.

    Returns:
        The RGB image resized to the dimensions chosen by `smart_resize`.

    Raises:
        ValueError: If the input is a `data:image` URI without a `base64,`
            payload, so no image could be produced.
        requests.HTTPError: If a URL source answers with an error status.
    """
    image = ele["image"] if "image" in ele else ele["image_url"]

    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith(("http://", "https://")):
        # Read the whole body before decoding: this surfaces HTTP errors as
        # such (instead of PIL failing on an error page), lets requests undo
        # any transfer encoding, and releases the connection when done.
        with requests.get(image, stream=True) as response:
            response.raise_for_status()
            image_obj = Image.open(BytesIO(response.content))
    elif image.startswith("file://"):
        image_obj = Image.open(image[len("file://"):])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        image_obj = Image.open(image)

    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = image_obj.convert("RGB")

    # Resize: explicit target dimensions win over pixel-count bounds.
    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=size_factor,
        )
    else:
        width, height = image.size
        min_pixels = ele.get("min_pixels", MIN_PIXELS)
        max_pixels = ele.get("max_pixels", MAX_PIXELS)
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=size_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    # PIL takes the size as (width, height).
    image = image.resize((resized_width, resized_height))

    return image
def smart_nframes(
    ele: dict,
    total_frames: int,
    video_fps: int | float,
) -> int:
    """Calculate how many frames of a video to use as model input.

    Args:
        ele (dict): Video configuration. Supports either `fps` or `nframes`:
            - nframes: the number of frames to extract for model inputs.
            - fps: the sampling rate used to extract frames for model inputs.
            - min_frames: minimum number of frames, only used on the fps path.
            - max_frames: maximum number of frames, only used on the fps path.
        total_frames (int): Original total number of frames of the video.
        video_fps (int | float): Original fps of the video.

    Raises:
        ValueError: If the resulting count is outside
            [FRAME_FACTOR, total_frames].

    Returns:
        int: Number of frames to sample, a multiple of FRAME_FACTOR.
    """
    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
    if "nframes" in ele:
        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
    else:
        fps = ele.get("fps", FPS)
        min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
        max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
        nframes = total_frames / video_fps * fps
        # A rate-derived count can never use more frames than the clip holds,
        # even when the caller's `max_frames` is larger than the clip.
        nframes = min(max(nframes, min_frames), max_frames, total_frames)
        nframes = round_by_factor(nframes, FRAME_FACTOR)
        if nframes > total_frames:
            # Rounding a clip-length clamp can overshoot by one step when
            # total_frames is not a multiple of FRAME_FACTOR.
            nframes = floor_by_factor(total_frames, FRAME_FACTOR)
    if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
        raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
    return nframes
def _read_video_torchvision(ele: dict,) -> torch.Tensor:
    """Read a video with torchvision.io.read_video and sample frames from it.

    Args:
        ele (dict): Video configuration. Supported keys:
            - video: path of the video; "file://", "http://", "https://" and
              local paths are accepted.
            - video_start: start time of the video, in seconds.
            - video_end: end time of the video, in seconds.
            Frame-count keys are forwarded to `smart_nframes`.

    Returns:
        torch.Tensor: The sampled frames with shape (T, C, H, W).
    """
    video_path = ele["video"]
    # NOTE(review): the source this was recovered from had lost its
    # indentation; both scheme checks are kept under the version guard —
    # confirm that torchvision >= 0.19.0 is meant to receive "file://" paths.
    if version.parse(torchvision.__version__) < version.parse("0.19.0"):
        # Match the scheme only at the start of the path; a substring test
        # would mis-handle paths that merely contain these markers.
        if video_path.startswith(("http://", "https://")):
            warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
        if video_path.startswith("file://"):
            video_path = video_path[len("file://"):]
    st = time.time()
    # The audio track is returned by read_video but not used here.
    video, _audio, info = io.read_video(
        video_path,
        start_pts=ele.get("video_start", 0.0),
        end_pts=ele.get("video_end", None),
        pts_unit="sec",
        output_format="TCHW",
    )
    total_frames, video_fps = video.size(0), info["video_fps"]
    logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    # Evenly spaced frame indices from the first to the last decoded frame.
    idx = torch.linspace(0, total_frames - 1, nframes).round().long()
    video = video[idx]
    return video
def is_decord_available() -> bool:
    """Report whether a module spec for the optional `decord` package can be found."""
    import importlib.util

    decord_spec = importlib.util.find_spec("decord")
    return decord_spec is not None
def
_read_video_decord
(
ele
:
dict
,)
->
torch
.
Tensor
:
"""read video using decord.VideoReader
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
torch.Tensor: the video tensor with shape (T, C, H, W).
"""
import
decord
video_path
=
ele
[
"video"
]
st
=
time
.
time
()
vr
=
decord
.
VideoReader
(
video_path
)
# TODO: support start_pts and end_pts
if
'video_start'
in
ele
or
'video_end'
in
ele
:
raise
NotImplementedError
(
"not support start_pts and end_pts in decord for now."
)
total_frames
,
video_fps
=
len
(
vr
),
vr
.
get_avg_fps
()
logger
.
info
(
f
"decord:
{
video_path
=
}
,
{
total_frames
=
}
,
{
video_fps
=
}
, time=
{
time
.
time
()
-
st
:.
3
f
}
s"
)
nframes
=
smart_nframes
(
ele
,
total_frames
=
total_frames
,
video_fps
=
video_fps
)
idx
=
torch
.
linspace
(
0
,
total_frames
-
1
,
nframes
).
round
().
long
().
tolist
()
video
=
vr
.
get_batch
(
idx
).
asnumpy
()
video
=
torch
.
tensor
(
video
).
permute
(
0
,
3
,
1
,
2
)
# Convert to TCHW format
return
video
# Registry mapping a backend name to the function that reads a video with it.
VIDEO_READER_BACKENDS = dict(
    decord=_read_video_decord,
    torchvision=_read_video_torchvision,
)

# Backend name taken from the environment at import time; None when unset.
FORCE_QWENVL_VIDEO_READER = os.environ.get("FORCE_QWENVL_VIDEO_READER")
@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
    """Pick the video reader backend name, announcing the choice on stderr.

    FORCE_QWENVL_VIDEO_READER wins when it is set; otherwise "decord" is used
    if it is available and "torchvision" if not. The result is cached, so the
    selection and the message happen once.
    """
    forced_backend = FORCE_QWENVL_VIDEO_READER
    if forced_backend is not None:
        backend = forced_backend
    else:
        backend = "decord" if is_decord_available() else "torchvision"
    print(f"qwen-vl-utils using {backend} to read video.", file=sys.stderr)
    return backend
def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
    """Load the video described by `ele` and resize its frames.

    `ele["video"]` is either a string, which is read through the selected
    reader backend and returned as a resized float tensor, or a list/tuple of
    frame images, each loaded through `fetch_image` and returned as a list
    padded with its last frame to a multiple of FRAME_FACTOR.
    """
    source = ele["video"]

    if not isinstance(source, str):
        # Frame-list input: every other key of `ele` is passed on to fetch_image.
        assert isinstance(source, (list, tuple))
        frame_options = ele.copy()
        frame_options.pop("type", None)
        frame_options.pop("video", None)
        frames = [
            fetch_image({"image": frame, **frame_options}, size_factor=image_factor)
            for frame in source
        ]
        padded_count = ceil_by_factor(len(frames), FRAME_FACTOR)
        missing = padded_count - len(frames)
        if missing > 0:
            frames.extend([frames[-1]] * missing)
        return frames

    # String input: decode with the configured backend.
    read_video = VIDEO_READER_BACKENDS[get_video_reader_backend()]
    clip = read_video(ele)
    nframes, _, height, width = clip.shape

    min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
    total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
    # Per-frame cap derived from the clip budget, never below ~105% of min_pixels.
    budget_cap = min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR)
    derived_max = max(budget_cap, int(min_pixels * 1.05))
    max_pixels = ele.get("max_pixels", derived_max)

    if "resized_height" in ele and "resized_width" in ele:
        target_h, target_w = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=image_factor,
        )
    else:
        target_h, target_w = smart_resize(
            height,
            width,
            factor=image_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )

    resized = transforms.functional.resize(
        clip,
        [target_h, target_w],
        interpolation=InterpolationMode.BICUBIC,
        antialias=True,
    )
    return resized.float()
def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
    """Collect the image/video content items from one or more conversations.

    Args:
        conversations: A single conversation (list of message dicts) or a list
            of conversations. Only messages whose `content` is a list are
            inspected.

    Returns:
        The content items, in order of appearance, that carry an `image`,
        `image_url` or `video` key, or whose `type` is one of those names.
        An empty input yields an empty list.
    """
    if not conversations:
        return []
    # Normalise a single conversation to a batch of one.
    if isinstance(conversations[0], dict):
        conversations = [conversations]

    vision_keys = ("image", "image_url", "video")
    vision_infos = []
    for conversation in conversations:
        for message in conversation:
            content = message["content"]
            if not isinstance(content, list):
                continue
            for ele in content:
                # `.get` so that items without a "type" key (for example a
                # bare {"text": ...}) are skipped instead of raising KeyError.
                if any(key in ele for key in vision_keys) or ele.get("type") in vision_keys:
                    vision_infos.append(ele)
    return vision_infos
def process_vision_info(
    conversations: list[dict] | list[list[dict]],
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]:
    """Load every image and video referenced in the conversations.

    Returns:
        `(image_inputs, video_inputs)`. Each is a list of loaded inputs in
        order of appearance, or None when no input of that kind was found.

    Raises:
        ValueError: If a vision item has none of the `image`, `image_url` or
            `video` keys.
    """
    loaded_images = []
    loaded_videos = []
    # Read images or videos
    for info in extract_vision_info(conversations):
        if "image" in info or "image_url" in info:
            loaded_images.append(fetch_image(info))
            continue
        if "video" in info:
            loaded_videos.append(fetch_video(info))
            continue
        raise ValueError("image, image_url or video should in content.")

    image_inputs = loaded_images if loaded_images else None
    video_inputs = loaded_videos if loaded_videos else None
    return image_inputs, video_inputs
wan/utils/utils.py
0 → 100644
View file @
4f71a2b0
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import
argparse
import
binascii
import
os
import
os.path
as
osp
import
imageio
import
torch
import
torchvision
# Names exported by `from <module> import *`.
__all__ = [
    'cache_video',
    'cache_image',
    'str2bool',
]
def rand_name(length=8, suffix=''):
    """Build a random lowercase-hex name from `length` random bytes.

    The name has two hex characters per byte. A non-empty `suffix` is
    appended, with a leading '.' added when it does not already start with one.
    """
    stem = os.urandom(length).hex()
    if not suffix:
        return stem
    extension = suffix if suffix.startswith('.') else '.' + suffix
    return stem + extension
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Encode a video tensor to a file, retrying on failure.

    Args:
        tensor: Video tensor. It is unbound along dimension 2 and each slice
            is laid out with `torchvision.utils.make_grid`
            (presumably (B, C, T, H, W) — TODO confirm against callers).
        save_file: Output path. When None, a random name with `suffix` under
            '/tmp' is used.
        fps: Frame rate passed to the writer.
        suffix: Extension for the generated name when `save_file` is None.
        nrow: Passed to `make_grid`.
        normalize: Passed to `make_grid`.
        value_range: (low, high) range used to clamp values and passed to
            `make_grid`.
        retry: Maximum number of attempts.

    Returns:
        The path written, or None when every attempt failed (the last error
        is printed).
    """
    # Output location.
    if save_file is None:
        cache_file = osp.join('/tmp', rand_name(suffix=suffix))
    else:
        cache_file = save_file

    frames = None
    error = None
    for _ in range(retry):
        try:
            if frames is None:
                # Preprocess once into separate locals. Rebinding `tensor`
                # here would make a retry after a failed write re-process the
                # already converted uint8 frames.
                clamped = tensor.clamp(min(value_range), max(value_range))
                grids = torch.stack([
                    torchvision.utils.make_grid(
                        u, nrow=nrow, normalize=normalize, value_range=value_range)
                    for u in clamped.unbind(2)
                ], dim=1).permute(1, 2, 3, 0)
                frames = (grids * 255).type(torch.uint8).cpu().numpy()

            # Write the video; always close the writer, even if a frame fails.
            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            try:
                for frame in frames:
                    writer.append_data(frame)
            finally:
                writer.close()
            return cache_file
        except Exception as e:
            # Best effort by design: remember the error and try again.
            error = e

    print(f'cache_video failed, error: {error}', flush=True)
    return None
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Save an image tensor to `save_file`, retrying on failure.

    Args:
        tensor: Image tensor handed to `torchvision.utils.save_image` after
            clamping to `value_range`.
        save_file: Output path.
        nrow: Passed to `save_image`.
        normalize: Passed to `save_image`.
        value_range: (low, high) range used to clamp values and passed to
            `save_image`.
        retry: Maximum number of attempts.

    Returns:
        `save_file` on success, or None when every attempt failed (the last
        error is printed, as `cache_video` does).
    """
    # NOTE(review): the original derived a file suffix from `save_file` and
    # fell back to '.png' for unrecognised extensions, but never used it; the
    # image has always been written to `save_file` unchanged. The dead
    # computation is dropped here — confirm whether a fallback was intended.
    error = None
    for _ in range(retry):
        try:
            clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            # Best effort by design: remember the error and try again.
            error = e

    # Report the failure instead of silently returning None.
    print(f'cache_image failed, error: {error}', flush=True)
    return None
def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'
    Matching ignores case and surrounding whitespace. A bool is returned
    unchanged; any other non-string value is converted with `str()` first.

    Args:
        v (str): String to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v
    # Go through str() so that values such as None or 1 end in the documented
    # ArgumentTypeError / a sensible result rather than an AttributeError.
    v_lower = str(v).strip().lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    if v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment