xuwx1 / LightX2V · Commit 86f7f033

support hunyuan i2v

Authored Apr 08, 2025 by helloyongyang; committed by Yang Yong(雍洋), Apr 08, 2025
Parent: 18532cd2

Showing 9 changed files with 538 additions and 106 deletions (+538 -106)
lightx2v/__main__.py                                                          +105  -21
lightx2v/text2v/models/networks/hunyuan/infer/pre_infer.py                    +14   -1
lightx2v/text2v/models/networks/hunyuan/infer/transformer_infer.py            +56   -55
lightx2v/text2v/models/networks/hunyuan/model.py                              +7    -2
lightx2v/text2v/models/schedulers/hunyuan/scheduler.py                        +143  -24
lightx2v/text2v/models/text_encoders/hf/llava/__init__.py                     +0    -0
lightx2v/text2v/models/text_encoders/hf/llava/model.py                        +162  -0
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/model.py    +13   -3
scripts/run_hunyuan_i2v.sh                                                    +38   -0
lightx2v/__main__.py  (view file @ 86f7f033)
@@ -5,12 +5,14 @@ import os
 import time
 import gc
 import json
+import torchvision
 import torchvision.transforms.functional as TF
 import numpy as np
 from PIL import Image
 from lightx2v.text2v.models.text_encoders.hf.llama.model import TextEncoderHFLlamaModel
 from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel
 from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel
+from lightx2v.text2v.models.text_encoders.hf.llava.model import TextEncoderHFLlavaModel
 from lightx2v.text2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
 from lightx2v.text2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerFeatureCaching
@@ -38,11 +40,14 @@ def load_models(args, model_config):
     init_device = torch.device("cuda")
     if args.model_cls == "hunyuan":
-        text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device)
+        if args.task == "t2v":
+            text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device)
+        else:
+            text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(args.model_path, "text_encoder_i2v"), init_device)
         text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
         text_encoders = [text_encoder_1, text_encoder_2]
-        model = HunyuanModel(args.model_path, model_config, init_device)
+        model = HunyuanModel(args.model_path, model_config, init_device, args)
-        vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device)
+        vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device, args=args)
     elif args.model_cls == "wan2.1":
         text_encoder = T5EncoderModel(
@@ -69,16 +74,26 @@ def load_models(args, model_config):
     return model, text_encoders, vae_model, image_encoder


-def set_target_shape(args):
+def set_target_shape(args, image_encoder_output):
     if args.model_cls == "hunyuan":
-        vae_scale_factor = 2 ** (4 - 1)
-        args.target_shape = (
-            1,
-            16,
-            (args.target_video_length - 1) // 4 + 1,
-            int(args.target_height) // vae_scale_factor,
-            int(args.target_width) // vae_scale_factor,
-        )
+        if args.task == "t2v":
+            vae_scale_factor = 2 ** (4 - 1)
+            args.target_shape = (
+                1,
+                16,
+                (args.target_video_length - 1) // 4 + 1,
+                int(args.target_height) // vae_scale_factor,
+                int(args.target_width) // vae_scale_factor,
+            )
+        elif args.task == "i2v":
+            vae_scale_factor = 2 ** (4 - 1)
+            args.target_shape = (
+                1,
+                16,
+                (args.target_video_length - 1) // 4 + 1,
+                int(image_encoder_output["target_height"]) // vae_scale_factor,
+                int(image_encoder_output["target_width"]) // vae_scale_factor,
+            )
     elif args.model_cls == "wan2.1":
         if args.task == "i2v":
             args.target_shape = (16, 21, args.lat_h, args.lat_w)
@@ -91,9 +106,75 @@ def set_target_shape(args):
         )


+def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
+    num_patches = round((base_size / patch_size) ** 2)
+    assert max_ratio >= 1.0
+    crop_size_list = []
+    wp, hp = num_patches, 1
+    while wp > 0:
+        if max(wp, hp) / min(wp, hp) <= max_ratio:
+            crop_size_list.append((wp * patch_size, hp * patch_size))
+        if (hp + 1) * wp <= num_patches:
+            hp += 1
+        else:
+            wp -= 1
+    return crop_size_list
+
+
+def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
+    aspect_ratio = float(height) / float(width)
+    diff_ratios = ratios - aspect_ratio
+    if aspect_ratio >= 1:
+        indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
+    else:
+        indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]
+    closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
+    closest_size = buckets[closest_ratio_id]
+    closest_ratio = ratios[closest_ratio_id]
+    return closest_size, closest_ratio
+
+
 def run_image_encoder(args, image_encoder, vae_model):
     if args.model_cls == "hunyuan":
-        return None
+        img = Image.open(args.image_path).convert("RGB")
+        origin_size = img.size
+
+        i2v_resolution = "720p"
+        if i2v_resolution == "720p":
+            bucket_hw_base_size = 960
+        elif i2v_resolution == "540p":
+            bucket_hw_base_size = 720
+        elif i2v_resolution == "360p":
+            bucket_hw_base_size = 480
+        else:
+            raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
+
+        crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
+        aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
+        closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
+
+        resize_param = min(closest_size)
+        center_crop_param = closest_size
+
+        ref_image_transform = torchvision.transforms.Compose(
+            [torchvision.transforms.Resize(resize_param), torchvision.transforms.CenterCrop(center_crop_param), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
+        )
+
+        semantic_image_pixel_values = [ref_image_transform(img)]
+        semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))
+
+        img_latents = vae_model.encode(semantic_image_pixel_values, args).mode()
+        scaling_factor = 0.476986
+        img_latents.mul_(scaling_factor)
+
+        target_height, target_width = closest_size
+
+        return {"img": img, "img_latents": img_latents, "target_height": target_height, "target_width": target_width}
     elif args.model_cls == "wan2.1":
         img = Image.open(args.image_path).convert("RGB")
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
@@ -124,11 +205,14 @@ def run_image_encoder(args, image_encoder, vae_model):
         raise NotImplementedError(f"Unsupported model class: {args.model_cls}")


-def run_text_encoder(args, text, text_encoders, model_config):
+def run_text_encoder(args, text, text_encoders, model_config, image_encoder_output):
     text_encoder_output = {}
     if args.model_cls == "hunyuan":
         for i, encoder in enumerate(text_encoders):
-            text_state, attention_mask = encoder.infer(text, args)
+            if args.task == "i2v" and i == 0:
+                text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], args)
+            else:
+                text_state, attention_mask = encoder.infer(text, args)
             text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
             text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
@@ -145,12 +229,12 @@ def run_text_encoder(args, text, text_encoders, model_config):
     return text_encoder_output


-def init_scheduler(args):
+def init_scheduler(args, image_encoder_output):
     if args.model_cls == "hunyuan":
         if args.feature_caching == "NoCaching":
-            scheduler = HunyuanScheduler(args)
+            scheduler = HunyuanScheduler(args, image_encoder_output)
         elif args.feature_caching == "TaylorSeer":
-            scheduler = HunyuanSchedulerFeatureCaching(args)
+            scheduler = HunyuanSchedulerFeatureCaching(args, image_encoder_output)
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
@@ -269,10 +353,10 @@ if __name__ == "__main__":
     else:
         image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}

-    text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config)
+    text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config, image_encoder_output)
-    set_target_shape(args)
+    set_target_shape(args, image_encoder_output)
-    scheduler = init_scheduler(args)
+    scheduler = init_scheduler(args, image_encoder_output)
     model.set_scheduler(scheduler)
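A quick worked example of the new i2v branch in set_target_shape (a standalone sketch, not part of the commit; the 720x1280 crop is a hypothetical value standing in for image_encoder_output["target_height"] / ["target_width"]):

    # Latent shape for a 33-frame request at a hypothetical 720x1280 crop bucket.
    target_video_length = 33
    target_height, target_width = 720, 1280        # would come from image_encoder_output
    vae_scale_factor = 2 ** (4 - 1)                 # 8x spatial compression of the causal VAE
    target_shape = (
        1,                                          # batch
        16,                                         # latent channels
        (target_video_length - 1) // 4 + 1,         # 9 latent frames (4x temporal compression)
        target_height // vae_scale_factor,          # 90
        target_width // vae_scale_factor,           # 160
    )
    print(target_shape)                             # (1, 16, 9, 90, 160)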
lightx2v/text2v/models/networks/hunyuan/infer/pre_infer.py  (view file @ 86f7f033)
@@ -8,12 +8,23 @@ class HunyuanPreInfer:
     def __init__(self):
         self.heads_num = 24

-    def infer(self, weights, x, t, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance):
+    def infer(self, weights, x, t, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance, img_latents=None):
+        if img_latents is not None:
+            token_replace_t = torch.zeros_like(t)
+            token_replace_vec = self.infer_time_in(weights, token_replace_t)
+            th = x.shape[-2] // 2
+            tw = x.shape[-1] // 2
+            frist_frame_token_num = th * tw
+
         time_out = self.infer_time_in(weights, t)
         img_out = self.infer_img_in(weights, x)
         infer_text_out = self.infer_text_in(weights, text_states, text_mask, t)
         infer_vector_out = self.infer_vector_in(weights, text_states_2)
         vec = time_out + infer_vector_out
+        if img_latents is not None:
+            token_replace_vec = token_replace_vec + infer_vector_out
         guidance_out = self.infer_guidance_in(weights, guidance)
         vec = vec + guidance_out

@@ -32,6 +43,8 @@ class HunyuanPreInfer:
             cu_seqlens_qkv[2 * i + 2] = s2
         max_seqlen_qkv = img_seq_len + txt_seq_len

+        if img_latents is not None:
+            return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin), token_replace_vec, frist_frame_token_num
         return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)

     def infer_time_in(self, weights, t):
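The new i2v arguments implement HunyuanVideo's first-frame "token replace": the tokens of the first latent frame are modulated with a zero timestep (they carry the clean reference image), while all other tokens keep the regular timestep embedding. A rough sketch of the bookkeeping, assuming a hypothetical 720x1280 input, 8x spatial VAE compression and 2x2 patchification:

    import torch

    latent_h, latent_w = 720 // 8, 1280 // 8        # 90, 160: spatial size of x before patchify
    th, tw = latent_h // 2, latent_w // 2           # 45, 80: token grid of one latent frame
    frist_frame_token_num = th * tw                 # 3600 tokens belong to the first frame
    t = torch.tensor([900.0])                       # current denoising timestep
    token_replace_t = torch.zeros_like(t)           # first-frame tokens are conditioned on t = 0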
lightx2v/text2v/models/networks/hunyuan/infer/transformer_infer.py  (view file @ 86f7f033)
@@ -25,10 +25,10 @@ class HunyuanTransformerInfer:
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

-    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
-        return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
+    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
+        return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

-    def _infer_with_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
+    def _infer_with_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
         txt_seq_len = txt.shape[0]
         img_seq_len = img.shape[0]
@@ -75,38 +75,22 @@ class HunyuanTransformerInfer:
         img = x[:img_seq_len, ...]
         return img, vec

-    def _infer_without_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
+    def _infer_without_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
         txt_seq_len = txt.shape[0]
         img_seq_len = img.shape[0]

         for i in range(self.double_blocks_num):
-            img, txt = self.infer_double_block(
-                weights.double_blocks_weights[i],
-                img,
-                txt,
-                vec,
-                cu_seqlens_qkv,
-                max_seqlen_qkv,
-                freqs_cis,
-            )
+            img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

         x = torch.cat((img, txt), 0)

         for i in range(self.single_blocks_num):
-            x = self.infer_single_block(
-                weights.single_blocks_weights[i],
-                x,
-                vec,
-                txt_seq_len,
-                cu_seqlens_qkv,
-                max_seqlen_qkv,
-                freqs_cis,
-            )
+            x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

         img = x[:img_seq_len, ...]
         return img, vec

-    def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
+    def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
         vec_silu = torch.nn.functional.silu(vec)
         img_mod_out = weights.img_mod.apply(vec_silu)
@@ -119,6 +103,13 @@ class HunyuanTransformerInfer:
             img_mod2_gate,
         ) = img_mod_out.chunk(6, dim=-1)

+        if token_replace_vec is not None:
+            token_replace_vec_silu = torch.nn.functional.silu(token_replace_vec)
+            token_replace_vec_img_mod_out = weights.img_mod.apply(token_replace_vec_silu)
+            (tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = token_replace_vec_img_mod_out.chunk(6, dim=-1)
+        else:
+            (tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = None, None, None, None, None, None
+
         txt_mod_out = weights.txt_mod.apply(vec_silu)
         (
             txt_mod1_shift,
@@ -129,7 +120,7 @@ class HunyuanTransformerInfer:
             txt_mod2_gate,
         ) = txt_mod_out.chunk(6, dim=-1)

-        img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis)
+        img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis)
         txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)

         q = torch.cat((img_q, txt_q), dim=0)
@@ -162,28 +153,19 @@ class HunyuanTransformerInfer:
         img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]

         img = self.infer_double_block_img_post_atten(
             weights,
-            img,
-            img_attn,
-            img_mod1_gate,
-            img_mod2_shift,
-            img_mod2_scale,
-            img_mod2_gate,
+            img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate, frist_frame_token_num
         )
-        txt = self.infer_double_block_txt_post_atten(
-            weights,
-            txt,
-            txt_attn,
-            txt_mod1_gate,
-            txt_mod2_shift,
-            txt_mod2_scale,
-            txt_mod2_gate,
-        )
+        txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
         return img, txt

-    def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis):
+    def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis):
         img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
-        img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
+        if tr_img_mod1_scale is not None:
+            x_zero = img_modulated[:frist_frame_token_num] * (1 + tr_img_mod1_scale) + tr_img_mod1_shift
+            x_orig = img_modulated[frist_frame_token_num:] * (1 + img_mod1_scale) + img_mod1_shift
+            img_modulated = torch.concat((x_zero, x_orig), dim=0)
+        else:
+            img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
         img_qkv = weights.img_attn_qkv.apply(img_modulated)
         img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
@@ -206,21 +188,24 @@ class HunyuanTransformerInfer:
         return txt_q, txt_k, txt_v

     def infer_double_block_img_post_atten(
         self,
-        weights,
-        img,
-        img_attn,
-        img_mod1_gate,
-        img_mod2_shift,
-        img_mod2_scale,
-        img_mod2_gate,
+        weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate, frist_frame_token_num
     ):
         out = weights.img_attn_proj.apply(img_attn)
-        out = out * img_mod1_gate
+        if tr_img_mod1_gate is not None:
+            x_zero = out[:frist_frame_token_num] * tr_img_mod1_gate
+            x_orig = out[frist_frame_token_num:] * img_mod1_gate
+            out = torch.concat((x_zero, x_orig), dim=0)
+        else:
+            out = out * img_mod1_gate
         img = img + out
         out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
-        out = out * (1 + img_mod2_scale) + img_mod2_shift
+        if tr_img_mod1_gate is not None:
+            x_zero = out[:frist_frame_token_num] * (1 + tr_img_mod2_scale) + tr_img_mod2_shift
+            x_orig = out[frist_frame_token_num:] * (1 + img_mod2_scale) + img_mod2_shift
+            out = torch.concat((x_zero, x_orig), dim=0)
+        else:
+            out = out * (1 + img_mod2_scale) + img_mod2_shift
         out = weights.img_mlp_fc1.apply(out)
         out = torch.nn.functional.gelu(out, approximate="tanh")
         out = weights.img_mlp_fc2.apply(out)
@@ -251,13 +236,23 @@ class HunyuanTransformerInfer:
         txt = txt + out
         return txt

-    def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
+    def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
         out = torch.nn.functional.silu(vec)
         out = weights.modulation.apply(out)
         mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)

+        if token_replace_vec is not None:
+            token_replace_vec_out = torch.nn.functional.silu(token_replace_vec)
+            token_replace_vec_out = weights.modulation.apply(token_replace_vec_out)
+            tr_mod_shift, tr_mod_scale, tr_mod_gate = token_replace_vec_out.chunk(3, dim=-1)
+
         out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
-        x_mod = out * (1 + mod_scale) + mod_shift
+        if token_replace_vec is not None:
+            x_zero = out[:frist_frame_token_num] * (1 + tr_mod_scale) + tr_mod_shift
+            x_orig = out[frist_frame_token_num:] * (1 + mod_scale) + mod_shift
+            x_mod = torch.concat((x_zero, x_orig), dim=0)
+        else:
+            x_mod = out * (1 + mod_scale) + mod_shift
         x_mod = weights.linear1.apply(x_mod)
@@ -301,6 +296,12 @@ class HunyuanTransformerInfer:
         out = torch.nn.functional.gelu(mlp, approximate="tanh")
         out = torch.cat((attn, out), 1)
         out = weights.linear2.apply(out)
-        out = out * mod_gate
+        if token_replace_vec is not None:
+            x_zero = out[:frist_frame_token_num] * tr_mod_gate
+            x_orig = out[frist_frame_token_num:] * mod_gate
+            out = torch.concat((x_zero, x_orig), dim=0)
+        else:
+            out = out * mod_gate
         x = x + out
         return x
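Every i2v change in this file follows the same pattern: wherever a shift/scale/gate modulation touches the image tokens, the first frist_frame_token_num tokens use the token-replace modulation (derived from the zero timestep) and the rest use the regular one, then both halves are concatenated back. A standalone sketch of that split with made-up sizes (not part of the commit):

    import torch

    # Hypothetical sizes: 3600 first-frame tokens out of 9 * 3600 = 32400 image tokens, hidden size 3072.
    tokens = torch.randn(32400, 3072)
    frist_frame_token_num = 3600
    scale, shift = torch.randn(3072), torch.randn(3072)        # regular modulation (from vec)
    tr_scale, tr_shift = torch.randn(3072), torch.randn(3072)  # token-replace modulation (from t = 0)

    x_zero = tokens[:frist_frame_token_num] * (1 + tr_scale) + tr_shift  # first latent frame
    x_orig = tokens[frist_frame_token_num:] * (1 + scale) + shift        # remaining frames
    out = torch.concat((x_zero, x_orig), dim=0)                          # same shape as the input
    assert out.shape == tokens.shape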
lightx2v/text2v/models/networks/hunyuan/model.py  (view file @ 86f7f033)
@@ -17,10 +17,11 @@ class HunyuanModel:
     post_weight_class = HunyuanPostWeights
     transformer_weight_class = HunyuanTransformerWeights

-    def __init__(self, model_path, config, device):
+    def __init__(self, model_path, config, device, args):
         self.model_path = model_path
         self.config = config
         self.device = device
+        self.args = args
         self._init_infer_class()
         self._init_weights()
         self._init_infer()
@@ -47,7 +48,10 @@ class HunyuanModel:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

     def _load_ckpt(self):
-        ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
+        if self.args.task == "t2v":
+            ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
+        else:
+            ckpt_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt")
         weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
         return weight_dict

@@ -96,6 +100,7 @@ class HunyuanModel:
             self.scheduler.freqs_cos,
             self.scheduler.freqs_sin,
             self.scheduler.guidance,
+            img_latents=image_encoder_output["img_latents"] if "img_latents" in image_encoder_output else None,
         )
         img, vec = self.transformer_infer.infer(self.transformer_weights, *pre_infer_out)
         self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape)
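Note how the image conditioning reaches the transformer without changing this call site: pre_infer.infer returns two extra items (token_replace_vec, frist_frame_token_num) only on the i2v path, and transformer_infer.infer declares them as keyword defaults, so *pre_infer_out unpacks cleanly in both cases. A toy illustration of that calling convention (the names below are placeholders, not the real classes):

    def transformer_infer(img, txt, vec, token_replace_vec=None, frist_frame_token_num=None):
        # token_replace_vec / frist_frame_token_num stay None on the t2v path
        return "i2v" if token_replace_vec is not None else "t2v"

    pre_infer_out_t2v = ("img", "txt", "vec")                   # shorter tuple, as in t2v
    pre_infer_out_i2v = ("img", "txt", "vec", "tr_vec", 3600)   # two extra items, as in i2v
    print(transformer_infer(*pre_infer_out_t2v))                # t2v
    print(transformer_infer(*pre_infer_out_i2v))                # i2v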
lightx2v/text2v/models/schedulers/hunyuan/scheduler.py  (view file @ 86f7f033)
 import torch
+import numpy as np
 from diffusers.utils.torch_utils import randn_tensor
 from typing import Union, Tuple, List
 from typing import Any, Callable, Dict, List, Optional, Union, Tuple
@@ -174,35 +175,108 @@ def get_nd_rotary_pos_embed(
 def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
     sigmas = torch.linspace(1, 0, num_inference_steps + 1)
     sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
-    timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.bfloat16, device=device)
+    timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32, device=device)
     return timesteps, sigmas


+def get_1d_rotary_pos_embed_riflex(
+    dim: int,
+    pos: Union[np.ndarray, int],
+    theta: float = 10000.0,
+    use_real=False,
+    k: Optional[int] = None,
+    L_test: Optional[int] = None,
+):
+    """
+    RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
+    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
+    data type.
+
+    Args:
+        dim (`int`): Dimension of the frequency tensor.
+        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+        theta (`float`, *optional*, defaults to 10000.0):
+            Scaling factor for frequency computation. Defaults to 10000.0.
+        use_real (`bool`, *optional*):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+        k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
+        L_test (`int`, *optional*, defaults to None): the number of frames for inference
+    Returns:
+        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+    """
+    assert dim % 2 == 0
+
+    if isinstance(pos, int):
+        pos = torch.arange(pos)
+    if isinstance(pos, np.ndarray):
+        pos = torch.from_numpy(pos)  # type: ignore  # [S]
+
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim))  # [D/2]
+    # === Riflex modification start ===
+    # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
+    # Empirical observations show that a few videos may exhibit repetition in the tail frames.
+    # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
+    if k is not None:
+        freqs[k - 1] = 0.9 * 2 * torch.pi / L_test
+    # === Riflex modification end ===
+
+    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
+    if use_real:
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
+        return freqs_cos, freqs_sin
+    else:
+        # lumina
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        return freqs_cis
+
+
 class HunyuanScheduler(BaseScheduler):
-    def __init__(self, args):
+    def __init__(self, args, image_encoder_output):
         super().__init__(args)
         self.infer_steps = self.args.infer_steps
+        self.image_encoder_output = image_encoder_output
         self.shift = 7.0
         self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
         assert len(self.timesteps) == self.infer_steps
         self.embedded_guidance_scale = 6.0
-        self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [42]]
+        self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [self.args.seed]]
         self.noise_pred = None
-        self.prepare_latents(shape=self.args.target_shape, dtype=torch.bfloat16)
+        self.prepare_latents(shape=self.args.target_shape, dtype=torch.float16)
         self.prepare_guidance()
-        self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=self.args.target_height, width=self.args.target_width)
+        if self.args.task == "t2v":
+            target_height, target_width = self.args.target_height, self.args.target_width
+        else:
+            target_height, target_width = self.image_encoder_output["target_height"], self.image_encoder_output["target_width"]
+        self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=target_height, width=target_width)

     def prepare_guidance(self):
         self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0

     def step_post(self):
-        sample = self.latents.to(torch.float32)
-        dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
-        prev_sample = sample + self.noise_pred.to(torch.float32) * dt
-        self.latents = prev_sample
+        if self.args.task == "t2v":
+            sample = self.latents.to(torch.float32)
+            dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
+            self.latents = sample + self.noise_pred.to(torch.float32) * dt
+        else:
+            sample = self.latents[:, :, 1:, :, :].to(torch.float32)
+            dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
+            latents = sample + self.noise_pred[:, :, 1:, :, :].to(torch.float32) * dt
+            self.latents = torch.concat([self.image_encoder_output["img_latents"], latents], dim=2)

     def prepare_latents(self, shape, dtype):
-        self.latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
+        if self.args.task == "t2v":
+            self.latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
+        else:
+            x1 = self.image_encoder_output["img_latents"].repeat(1, 1, (self.args.target_video_length - 1) // 4 + 1, 1, 1)
+            x0 = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
+            t = torch.tensor([0.999]).to(device=torch.device("cuda"))
+            self.latents = x0 * t + x1 * (1 - t)
+            self.latents = self.latents.to(dtype=dtype)
+            self.latents = torch.concat([self.image_encoder_output["img_latents"], self.latents[:, :, 1:, :, :]], dim=2)

     def prepare_rotary_pos_embedding(self, video_length, height, width):
         target_ndim = 3
@@ -230,17 +304,62 @@ class HunyuanScheduler(BaseScheduler):
         if len(rope_sizes) != target_ndim:
             rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis

-        head_dim = hidden_size // heads_num
-        rope_dim_list = rope_dim_list
-        if rope_dim_list is None:
-            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
-        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
-        self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed(
-            rope_dim_list,
-            rope_sizes,
-            theta=rope_theta,
-            use_real=True,
-            theta_rescale_factor=1,
-        )
-        self.freqs_cos = self.freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
-        self.freqs_sin = self.freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
+        if self.args.task == "t2v":
+            head_dim = hidden_size // heads_num
+            rope_dim_list = rope_dim_list
+            if rope_dim_list is None:
+                rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+            assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
+            self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed(
+                rope_dim_list,
+                rope_sizes,
+                theta=rope_theta,
+                use_real=True,
+                theta_rescale_factor=1,
+            )
+            self.freqs_cos = self.freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
+            self.freqs_sin = self.freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
+        else:
+            L_test = rope_sizes[0]  # Latent frames
+            L_train = 25  # Training length from HunyuanVideo
+            actual_num_frames = video_length  # Use input video_length directly
+
+            head_dim = hidden_size // heads_num
+            rope_dim_list = rope_dim_list or [head_dim // target_ndim for _ in range(target_ndim)]
+            assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal head_dim"
+
+            if actual_num_frames > 192:
+                k = 2 + ((actual_num_frames + 3) // (4 * L_train))
+                k = max(4, min(8, k))
+                # Compute positional grids for RIFLEx
+                axes_grids = [torch.arange(size, device=torch.device("cuda"), dtype=torch.float32) for size in rope_sizes]
+                grid = torch.meshgrid(*axes_grids, indexing="ij")
+                grid = torch.stack(grid, dim=0)  # [3, t, h, w]
+                pos = grid.reshape(3, -1).t()  # [t * h * w, 3]
+
+                # Apply RIFLEx to temporal dimension
+                freqs = []
+                for i in range(3):
+                    if i == 0:  # Temporal with RIFLEx
+                        freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], pos[:, i], theta=rope_theta, use_real=True, k=k, L_test=L_test)
+                    else:  # Spatial with default RoPE
+                        freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], pos[:, i], theta=rope_theta, use_real=True, k=None, L_test=None)
+                    freqs.append((freqs_cos, freqs_sin))
+                freqs_cos = torch.cat([f[0] for f in freqs], dim=1)
+                freqs_sin = torch.cat([f[1] for f in freqs], dim=1)
+            else:
+                # 20250316 pftq: Original code for <= 192 frames
+                freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
+                    rope_dim_list,
+                    rope_sizes,
+                    theta=rope_theta,
+                    use_real=True,
+                    theta_rescale_factor=1,
+                )
+            self.freqs_cos = freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
+            self.freqs_sin = freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
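The scheduler keeps HunyuanVideo's shifted flow-matching schedule: sigmas run linearly from 1 to 0 and are warped by sigma' = shift * sigma / (1 + (shift - 1) * sigma) with shift = 7.0, which concentrates steps near the noisy end; this commit only changes the dtype of the returned timesteps from bfloat16 to float32. A quick numeric check of the warp (standalone sketch, not part of the diff):

    import torch

    shift, num_inference_steps = 7.0, 4
    sigmas = torch.linspace(1, 0, num_inference_steps + 1)   # [1.00, 0.75, 0.50, 0.25, 0.00]
    sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)   # [1.0000, 0.9545, 0.8750, 0.7000, 0.0000]
    timesteps = sigmas[:-1] * 1000                           # [1000.0, 954.5, 875.0, 700.0]
    print(sigmas, timesteps)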
lightx2v/text2v/models/text_encoders/hf/llava/__init__.py  (new empty file, 0 → 100755, view file @ 86f7f033)
lightx2v/text2v/models/text_encoders/hf/llava/model.py  (new file, 0 → 100755, view file @ 86f7f033)

import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from transformers import LlavaForConditionalGeneration, CLIPImageProcessor, AutoTokenizer


def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
    """generate crop size list

    Args:
        base_size (int, optional): the base size for generate bucket. Defaults to 256.
        patch_size (int, optional): the stride to generate bucket. Defaults to 32.
        max_ratio (float, optional): th max ratio for h or w based on base_size . Defaults to 4.0.

    Returns:
        list: generate crop size list
    """
    num_patches = round((base_size / patch_size) ** 2)
    assert max_ratio >= 1.0
    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list


def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
    """get the closest ratio in the buckets

    Args:
        height (float): video height
        width (float): video width
        ratios (list): video aspect ratio
        buckets (list): buckets generate by `generate_crop_size_list`

    Returns:
        the closest ratio in the buckets and the corresponding ratio
    """
    aspect_ratio = float(height) / float(width)
    diff_ratios = ratios - aspect_ratio

    if aspect_ratio >= 1:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
    else:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]

    closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
    closest_size = buckets[closest_ratio_id]
    closest_ratio = ratios[closest_ratio_id]

    return closest_size, closest_ratio


class TextEncoderHFLlavaModel:
    def __init__(self, model_path, device):
        self.device = device
        self.model_path = model_path
        self.init()
        self.load()

    def init(self):
        self.max_length = 359
        self.hidden_state_skip_layer = 2
        self.crop_start = 103
        self.double_return_token_id = 271
        self.image_emb_len = 576
        self.text_crop_start = self.crop_start - 1 + self.image_emb_len
        self.image_crop_start = 5
        self.image_crop_end = 581
        self.image_embed_interleave = 4
        self.prompt_template = (
            "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
            "1. The main content and theme of the video."
            "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
            "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
            "4. background environment, light, style and atmosphere."
            "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
            "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )

    def load(self):
        self.model = LlavaForConditionalGeneration.from_pretrained(self.model_path, low_cpu_mem_usage=True).to(torch.float16).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")
        self.processor = CLIPImageProcessor.from_pretrained(self.model_path)

    def to_cpu(self):
        self.model = self.model.to("cpu")

    def to_cuda(self):
        self.model = self.model.to("cuda")

    @torch.no_grad()
    def infer(self, text, img, args):
        # if args.cpu_offload:
        #     self.to_cuda()
        text = self.prompt_template.format(text)
        print(f"text: {text}")
        tokens = self.tokenizer(
            text,
            return_length=False,
            return_overflowing_tokens=False,
            return_attention_mask=True,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        ).to("cuda")

        image_outputs = self.processor(img, return_tensors="pt")["pixel_values"].to(self.device)

        attention_mask = tokens["attention_mask"].to(self.device)
        outputs = self.model(
            input_ids=tokens["input_ids"],
            attention_mask=attention_mask,
            output_hidden_states=True,
            pixel_values=image_outputs,
        )
        last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)]

        batch_indices, last_double_return_token_indices = torch.where(tokens["input_ids"] == self.double_return_token_id)
        last_double_return_token_indices = last_double_return_token_indices.reshape(1, -1)[:, -1]
        assistant_crop_start = last_double_return_token_indices - 1 + self.image_emb_len - 4
        assistant_crop_end = last_double_return_token_indices - 1 + self.image_emb_len
        attention_mask_assistant_crop_start = last_double_return_token_indices - 4
        attention_mask_assistant_crop_end = last_double_return_token_indices

        text_last_hidden_state = torch.cat([last_hidden_state[0, self.text_crop_start : assistant_crop_start[0].item()], last_hidden_state[0, assistant_crop_end[0].item() :]])
        text_attention_mask = torch.cat([attention_mask[0, self.crop_start : attention_mask_assistant_crop_start[0].item()], attention_mask[0, attention_mask_assistant_crop_end[0].item() :]])
        image_last_hidden_state = last_hidden_state[0, self.image_crop_start : self.image_crop_end]
        image_attention_mask = torch.ones(image_last_hidden_state.shape[0]).to(last_hidden_state.device).to(attention_mask.dtype)

        text_last_hidden_state.unsqueeze_(0)
        text_attention_mask.unsqueeze_(0)
        image_last_hidden_state.unsqueeze_(0)
        image_attention_mask.unsqueeze_(0)

        image_last_hidden_state = image_last_hidden_state[:, :: self.image_embed_interleave, :]
        image_attention_mask = image_attention_mask[:, :: self.image_embed_interleave]

        last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
        attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)

        # if args.cpu_offload:
        #     self.to_cpu()
        return last_hidden_state, attention_mask


if __name__ == "__main__":
    model = TextEncoderHFLlavaModel("/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/i2v/text_encoder_i2v", torch.device("cuda"))
    text = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
    img_path = "/mtc/yongyang/projects/lightx2v/assets/inputs/imgs/img_1.jpg"
    img = Image.open(img_path).convert("RGB")
    outputs = model.infer(text, img, None)
    print(outputs)
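One detail worth calling out in the returned shapes: the LLaVA vision tower contributes image_emb_len = 576 image tokens, but infer keeps only every image_embed_interleave-th one (576 / 4 = 144) before concatenating them in front of the text hidden states. A small sketch of that subsampling (the hidden size below is a placeholder, not taken from the model config):

    import torch

    image_emb_len, hidden = 576, 4096
    image_embed_interleave = 4
    image_hidden = torch.randn(1, image_emb_len, hidden)
    image_hidden = image_hidden[:, ::image_embed_interleave, :]
    print(image_hidden.shape)    # torch.Size([1, 144, 4096])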
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/model.py  (view file @ 86f7f033)
 import os
 import torch
-from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
+from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D, DiagonalGaussianDistribution


 class VideoEncoderKLCausal3DModel:
-    def __init__(self, model_path, dtype, device):
+    def __init__(self, model_path, dtype, device, args):
         self.model_path = model_path
         self.dtype = dtype
         self.device = device
+        self.args = args
         self.load()

     def load(self):
-        self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
+        if self.args.task == "t2v":
+            self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
+        else:
+            self.vae_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/vae")
         config = AutoencoderKLCausal3D.load_config(self.vae_path)
         self.model = AutoencoderKLCausal3D.from_config(config)
         ckpt = torch.load(os.path.join(self.vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)

@@ -39,6 +43,12 @@ class VideoEncoderKLCausal3DModel:
             self.to_cpu()
         return image

+    def encode(self, x, args):
+        h = self.model.encoder(x)
+        moments = self.model.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+

 if __name__ == "__main__":
     model_path = ""
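run_image_encoder in __main__.py uses this new encode method to turn the reference image into a latent: it takes the mode of the returned DiagonalGaussianDistribution and scales it by 0.476986. A standalone sketch of what that amounts to, assuming the distribution follows the usual diffusers convention of stacking mean and logvar along the channel axis (shapes are hypothetical):

    import torch

    moments = torch.randn(1, 32, 1, 90, 160)     # quant_conv output: 2 * 16 channels (mean + logvar)
    mean, logvar = torch.chunk(moments, 2, dim=1)
    img_latents = mean.clone()                   # DiagonalGaussianDistribution.mode() returns the mean
    img_latents.mul_(0.476986)                   # VAE scaling factor applied in run_image_encoder
    print(img_latents.shape)                     # torch.Size([1, 16, 1, 90, 160])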
scripts/run_hunyuan_i2v.sh  (new file, 0 → 100755, view file @ 86f7f033)

#!/bin/bash

# set path and first
lightx2v_path=""
model_path=""

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: 0, change at shell script or set env variable."
    cuda_devices="0"
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

python ${lightx2v_path}/lightx2v/__main__.py \
    --model_cls hunyuan \
    --model_path $model_path \
    --task i2v \
    --prompt "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick." \
    --image_path ${lightx2v_path}/assets/inputs/imgs/img_1.jpg \
    --infer_steps 20 \
    --target_video_length 33 \
    --target_height 720 \
    --target_width 1280 \
    --attention_type flash_attn2 \
    --save_video_path ./output_lightx2v_hy_i2v.mp4 \
    --seed 0