xuwx1 / LightX2V — Commit efb4d161
Authored Apr 20, 2025 by helloyongyang
Parent: f4b343f6

Remove args parameter passing; pass everything through config instead, simplifying the code.
Showing 14 changed files with 209 additions and 202 deletions.
lightx2v/__main__.py                                                          +107  -126
lightx2v/text2v/models/networks/hunyuan/infer/post_infer.py                     +7    -4
lightx2v/text2v/models/networks/hunyuan/infer/pre_infer.py                     +19    -5
lightx2v/text2v/models/networks/hunyuan/model.py                               +10   -18
lightx2v/text2v/models/networks/wan/infer/post_infer.py                         +3    -0
lightx2v/text2v/models/networks/wan/infer/pre_infer.py                         +14    -1
lightx2v/text2v/models/networks/wan/model.py                                    +6   -24
lightx2v/text2v/models/text_encoders/hf/clip/model.py                           +3    -3
lightx2v/text2v/models/text_encoders/hf/llama/model.py                          +3    -3
lightx2v/text2v/models/text_encoders/hf/llava/model.py                          +5    -5
lightx2v/text2v/models/text_encoders/hf/t5/model.py                             +3    -3
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/model.py      +7    -7
lightx2v/text2v/models/video_encoders/hf/wan/vae.py                             +3    -3
lightx2v/utils/set_config.py (new file)                                        +19    -0
lightx2v/__main__.py

...
@@ -29,6 +29,7 @@ from lightx2v.text2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
 from lightx2v.common.ops import *
 from lightx2v.image2v.models.wan.model import CLIPModel
+from lightx2v.utils.set_config import set_config

 @contextmanager
...
@@ -41,92 +42,92 @@ def time_duration(label: str = ""):
     print(f"==> {label} start: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))} cost {end_time - start_time:.2f} seconds")

-def load_models(args, model_config):
-    if model_config["parallel_attn_type"]:
+def load_models(config):
+    if config["parallel_attn_type"]:
         cur_rank = dist.get_rank()  # get the rank of the current process
         torch.cuda.set_device(cur_rank)  # set the CUDA device for the current process
     image_encoder = None
-    if args.cpu_offload:
+    if config.cpu_offload:
         init_device = torch.device("cpu")
     else:
         init_device = torch.device("cuda")
-    if args.model_cls == "hunyuan":
-        if args.task == "t2v":
-            text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device)
+    if config.model_cls == "hunyuan":
+        if config.task == "t2v":
+            text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(config.model_path, "text_encoder"), init_device)
         else:
-            text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(args.model_path, "text_encoder_i2v"), init_device)
-        text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
+            text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(config.model_path, "text_encoder_i2v"), init_device)
+        text_encoder_2 = TextEncoderHFClipModel(os.path.join(config.model_path, "text_encoder_2"), init_device)
         text_encoders = [text_encoder_1, text_encoder_2]
-        model = HunyuanModel(args.model_path, model_config, init_device, args)
-        vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device, args=args)
+        model = HunyuanModel(config.model_path, config, init_device, config)
+        vae_model = VideoEncoderKLCausal3DModel(config.model_path, dtype=torch.float16, device=init_device, config=config)
-    elif args.model_cls == "wan2.1":
+    elif config.model_cls == "wan2.1":
         with time_duration("Load Text Encoder"):
             text_encoder = T5EncoderModel(
-                text_len=model_config["text_len"],
+                text_len=config["text_len"],
                 dtype=torch.bfloat16,
                 device=init_device,
-                checkpoint_path=os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
-                tokenizer_path=os.path.join(args.model_path, "google/umt5-xxl"),
+                checkpoint_path=os.path.join(config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
+                tokenizer_path=os.path.join(config.model_path, "google/umt5-xxl"),
                 shard_fn=None,
             )
         text_encoders = [text_encoder]
         with time_duration("Load Wan Model"):
-            model = WanModel(args.model_path, model_config, init_device)
-        if args.lora_path:
+            model = WanModel(config.model_path, config, init_device)
+        if config.lora_path:
             lora_wrapper = WanLoraWrapper(model)
             with time_duration("Load LoRA Model"):
-                lora_name = lora_wrapper.load_lora(args.lora_path)
-                lora_wrapper.apply_lora(lora_name, args.strength_model)
+                lora_name = lora_wrapper.load_lora(config.lora_path)
+                lora_wrapper.apply_lora(lora_name, config.strength_model)
                 print(f"Loaded LoRA: {lora_name}")
         with time_duration("Load WAN VAE Model"):
-            vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=args.parallel_vae)
+            vae_model = WanVAE(vae_pth=os.path.join(config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=config.parallel_vae)
-        if args.task == "i2v":
+        if config.task == "i2v":
             with time_duration("Load Image Encoder"):
                 image_encoder = CLIPModel(
                     dtype=torch.float16,
                     device=init_device,
-                    checkpoint_path=os.path.join(args.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
-                    tokenizer_path=os.path.join(args.model_path, "xlm-roberta-large"),
+                    checkpoint_path=os.path.join(config.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
+                    tokenizer_path=os.path.join(config.model_path, "xlm-roberta-large"),
                 )
     else:
-        raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
+        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
     return model, text_encoders, vae_model, image_encoder
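
Note: the refactor relies on the config object accepting both attribute access (config.cpu_offload) and dict-style access (config["parallel_attn_type"]), which is why both spellings appear above. A minimal sketch of that duality, assuming only the easydict package that set_config.py below imports:

    from easydict import EasyDict

    config = EasyDict({"model_cls": "wan2.1", "cpu_offload": False})
    assert config.model_cls == config["model_cls"]  # attribute and key access are interchangeable
    config.lat_h = 60                               # new fields can be attached on the fly,
    assert config["lat_h"] == 60                    # as set_target_shape and run_image_encoder do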
-def set_target_shape(args, image_encoder_output):
-    if args.model_cls == "hunyuan":
-        if args.task == "t2v":
+def set_target_shape(config, image_encoder_output):
+    if config.model_cls == "hunyuan":
+        if config.task == "t2v":
             vae_scale_factor = 2 ** (4 - 1)
-            args.target_shape = (
+            config.target_shape = (
                 1,
                 16,
-                (args.target_video_length - 1) // 4 + 1,
-                int(args.target_height) // vae_scale_factor,
-                int(args.target_width) // vae_scale_factor,
+                (config.target_video_length - 1) // 4 + 1,
+                int(config.target_height) // vae_scale_factor,
+                int(config.target_width) // vae_scale_factor,
             )
-        elif args.task == "i2v":
+        elif config.task == "i2v":
             vae_scale_factor = 2 ** (4 - 1)
-            args.target_shape = (
+            config.target_shape = (
                 1,
                 16,
-                (args.target_video_length - 1) // 4 + 1,
+                (config.target_video_length - 1) // 4 + 1,
                 int(image_encoder_output["target_height"]) // vae_scale_factor,
                 int(image_encoder_output["target_width"]) // vae_scale_factor,
             )
-    elif args.model_cls == "wan2.1":
-        if args.task == "i2v":
-            args.target_shape = (16, 21, args.lat_h, args.lat_w)
-        elif args.task == "t2v":
-            args.target_shape = (
+    elif config.model_cls == "wan2.1":
+        if config.task == "i2v":
+            config.target_shape = (16, 21, config.lat_h, config.lat_w)
+        elif config.task == "t2v":
+            config.target_shape = (
                 16,
-                (args.target_video_length - 1) // 4 + 1,
-                int(args.target_height) // args.vae_stride[1],
-                int(args.target_width) // args.vae_stride[2],
+                (config.target_video_length - 1) // 4 + 1,
+                int(config.target_height) // config.vae_stride[1],
+                int(config.target_width) // config.vae_stride[2],
             )
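
As a worked example of the wan2.1 t2v branch, with assumed values that do not appear in this diff (target_video_length = 81, target_height = 480, target_width = 832, vae_stride = (4, 8, 8)): the temporal latent length is (81 - 1) // 4 + 1 = 21 and the spatial latents are 480 // 8 = 60 and 832 // 8 = 104, so config.target_shape becomes (16, 21, 60, 104), consistent with the hard-coded 21 in the i2v branch.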
...
@@ -161,9 +162,9 @@ def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
     return closest_size, closest_ratio

-def run_image_encoder(args, image_encoder, vae_model):
-    if args.model_cls == "hunyuan":
-        img = Image.open(args.image_path).convert("RGB")
+def run_image_encoder(config, image_encoder, vae_model):
+    if config.model_cls == "hunyuan":
+        img = Image.open(config.image_path).convert("RGB")
         origin_size = img.size
         i2v_resolution = "720p"
...
@@ -190,7 +191,7 @@ def run_image_encoder(args, image_encoder, vae_model):
         semantic_image_pixel_values = [ref_image_transform(img)]
         semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))
-        img_latents = vae_model.encode(semantic_image_pixel_values, args).mode()
+        img_latents = vae_model.encode(semantic_image_pixel_values, config).mode()
         scaling_factor = 0.476986
         img_latents.mul_(scaling_factor)
...
@@ -199,20 +200,20 @@ def run_image_encoder(args, image_encoder, vae_model):
         return {"img": img, "img_latents": img_latents, "target_height": target_height, "target_width": target_width}
-    elif args.model_cls == "wan2.1":
-        img = Image.open(args.image_path).convert("RGB")
+    elif config.model_cls == "wan2.1":
+        img = Image.open(config.image_path).convert("RGB")
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = image_encoder.visual([img[:, None, :, :]], args).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = image_encoder.visual([img[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
         h, w = img.shape[1:]
         aspect_ratio = h / w
-        max_area = args.target_height * args.target_width
-        lat_h = round(np.sqrt(max_area * aspect_ratio) // args.vae_stride[1] // args.patch_size[1] * args.patch_size[1])
-        lat_w = round(np.sqrt(max_area / aspect_ratio) // args.vae_stride[2] // args.patch_size[2] * args.patch_size[2])
-        h = lat_h * args.vae_stride[1]
-        w = lat_w * args.vae_stride[2]
-        args.lat_h = lat_h
-        args.lat_w = lat_w
+        max_area = config.target_height * config.target_width
+        lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
+        lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
+        h = lat_h * config.vae_stride[1]
+        w = lat_w * config.vae_stride[2]
+        config.lat_h = lat_h
+        config.lat_w = lat_w
         msk = torch.ones(1, 81, lat_h, lat_w, device=torch.device("cuda"))
         msk[:, 1:] = 0
...
@@ -220,64 +221,64 @@ def run_image_encoder(args, image_encoder, vae_model):
         msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
         msk = msk.transpose(1, 2)[0]
         vae_encode_out = vae_model.encode(
-            [torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], args
+            [torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], config
         )[0]
         vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
         return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
     else:
-        raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
+        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
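
The msk tensor marks which of the 81 frames are conditioned on the input image (only frame 0); the elided context between msk[:, 1:] = 0 and the view above pads the temporal axis (in the upstream Wan2.1 reference code the first frame is repeated four times so the length divides by 4), and the view/transpose pair then regroups frames into the VAE's 4-frame temporal blocks. A small runnable sketch of that regrouping on hypothetical tiny sizes (lat_h = lat_w = 2, 5 frames):

    import torch

    msk = torch.ones(1, 5, 2, 2)
    msk[:, 1:] = 0  # condition only on frame 0
    # assumed padding step, as in upstream Wan2.1 (hidden by the "..." above):
    msk = torch.cat([msk[:, 0:1].repeat_interleave(4, dim=1), msk[:, 1:]], dim=1)  # 5 -> 8 frames
    msk = msk.view(1, msk.shape[1] // 4, 4, 2, 2)  # group into 4-frame temporal blocks
    msk = msk.transpose(1, 2)[0]                   # -> (4, n_blocks, lat_h, lat_w)
    print(msk.shape)  # torch.Size([4, 2, 2, 2])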
-def run_text_encoder(args, text, text_encoders, model_config, image_encoder_output):
+def run_text_encoder(text, text_encoders, config, image_encoder_output):
     text_encoder_output = {}
-    if args.model_cls == "hunyuan":
+    if config.model_cls == "hunyuan":
         for i, encoder in enumerate(text_encoders):
-            if args.task == "i2v" and i == 0:
-                text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], args)
+            if config.task == "i2v" and i == 0:
+                text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], config)
             else:
-                text_state, attention_mask = encoder.infer(text, args)
+                text_state, attention_mask = encoder.infer(text, config)
             text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
             text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
-    elif args.model_cls == "wan2.1":
-        n_prompt = model_config.get("sample_neg_prompt", "")
-        context = text_encoders[0].infer([text], args)
-        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], args)
+    elif config.model_cls == "wan2.1":
+        n_prompt = config.get("sample_neg_prompt", "")
+        context = text_encoders[0].infer([text], config)
+        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], config)
         text_encoder_output["context"] = context
         text_encoder_output["context_null"] = context_null
     else:
-        raise NotImplementedError(f"Unsupported model type: {args.model_cls}")
+        raise NotImplementedError(f"Unsupported model type: {config.model_cls}")
     return text_encoder_output
-def init_scheduler(args, image_encoder_output):
-    if args.model_cls == "hunyuan":
-        if args.feature_caching == "NoCaching":
-            scheduler = HunyuanScheduler(args, image_encoder_output)
-        elif args.feature_caching == "Tea":
-            scheduler = HunyuanSchedulerTeaCaching(args, image_encoder_output)
-        elif args.feature_caching == "TaylorSeer":
-            scheduler = HunyuanSchedulerTaylorCaching(args, image_encoder_output)
+def init_scheduler(config, image_encoder_output):
+    if config.model_cls == "hunyuan":
+        if config.feature_caching == "NoCaching":
+            scheduler = HunyuanScheduler(config, image_encoder_output)
+        elif config.feature_caching == "Tea":
+            scheduler = HunyuanSchedulerTeaCaching(config, image_encoder_output)
+        elif config.feature_caching == "TaylorSeer":
+            scheduler = HunyuanSchedulerTaylorCaching(config, image_encoder_output)
         else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
+            raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")
-    elif args.model_cls == "wan2.1":
-        if args.feature_caching == "NoCaching":
-            scheduler = WanScheduler(args)
-        elif args.feature_caching == "Tea":
-            scheduler = WanSchedulerTeaCaching(args)
+    elif config.model_cls == "wan2.1":
+        if config.feature_caching == "NoCaching":
+            scheduler = WanScheduler(config)
+        elif config.feature_caching == "Tea":
+            scheduler = WanSchedulerTeaCaching(config)
         else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
+            raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")
     else:
-        raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
+        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
     return scheduler
-def run_main_inference(args, model, text_encoder_output, image_encoder_output):
+def run_main_inference(model, inputs):
     for step_index in range(model.scheduler.infer_steps):
         torch.cuda.synchronize()
         time1 = time.time()
...
@@ -287,7 +288,7 @@ def run_main_inference(args, model, text_encoder_output, image_encoder_output):
         torch.cuda.synchronize()
         time2 = time.time()
-        model.infer(text_encoder_output, image_encoder_output, args)
+        model.infer(inputs)
         torch.cuda.synchronize()
         time3 = time.time()
...
@@ -304,8 +305,8 @@ def run_main_inference(args, model, text_encoder_output, image_encoder_output):
     return model.scheduler.latents, model.scheduler.generator

-def run_vae(latents, generator, args):
-    images = vae_model.decode(latents, generator=generator, args=args)
+def run_vae(latents, generator, config):
+    images = vae_model.decode(latents, generator=generator, config=config)
     return images
...
@@ -348,69 +349,49 @@ if __name__ == "__main__":
     seed_all(args.seed)

-    if args.parallel_attn_type:
+    config = set_config(args)
+
+    if config.parallel_attn_type:
         dist.init_process_group(backend="nccl")

-    if args.mm_config:
-        mm_config = json.loads(args.mm_config)
-    else:
-        mm_config = None
-
-    model_config = {
-        "model_cls": args.model_cls,
-        "task": args.task,
-        "attention_type": args.attention_type,
-        "sample_neg_prompt": args.sample_neg_prompt,
-        "mm_config": mm_config,
-        "do_mm_calib": args.do_mm_calib,
-        "cpu_offload": args.cpu_offload,
-        "feature_caching": args.feature_caching,
-        "parallel_attn_type": args.parallel_attn_type,
-        "parallel_vae": args.parallel_vae,
-        "use_bfloat16": args.use_bfloat16,
-    }
-    if args.config_path is not None:
-        with open(args.config_path, "r") as f:
-            config = json.load(f)
-        model_config.update(config)
-    print(f"model_config: {model_config}")
+    print(f"config: {config}")

     with time_duration("Load models"):
-        model, text_encoders, vae_model, image_encoder = load_models(args, model_config)
+        model, text_encoders, vae_model, image_encoder = load_models(config)

-    if args.task in ["i2v"]:
-        image_encoder_output = run_image_encoder(args, image_encoder, vae_model)
+    if config["task"] in ["i2v"]:
+        image_encoder_output = run_image_encoder(config, image_encoder, vae_model)
     else:
         image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}
     with time_duration("Run Text Encoder"):
-        text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config, image_encoder_output)
+        text_encoder_output = run_text_encoder(config["prompt"], text_encoders, config, image_encoder_output)
+
+    inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}

-    set_target_shape(args, image_encoder_output)
-    scheduler = init_scheduler(args, image_encoder_output)
+    set_target_shape(config, image_encoder_output)
+    scheduler = init_scheduler(config, image_encoder_output)
     model.set_scheduler(scheduler)

     gc.collect()
     torch.cuda.empty_cache()
-    latents, generator = run_main_inference(args, model, text_encoder_output, image_encoder_output)
+    latents, generator = run_main_inference(model, inputs)
-    if args.cpu_offload:
+    if config.cpu_offload:
         scheduler.clear()
         del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
         torch.cuda.empty_cache()
     with time_duration("Run VAE"):
-        images = run_vae(latents, generator, args)
+        images = run_vae(latents, generator, config)

-    if not args.parallel_attn_type or (args.parallel_attn_type and dist.get_rank() == 0):
+    if not config.parallel_attn_type or (config.parallel_attn_type and dist.get_rank() == 0):
         with time_duration("Save video"):
-            if args.model_cls == "wan2.1":
-                cache_video(tensor=images, save_file=args.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
+            if config.model_cls == "wan2.1":
+                cache_video(tensor=images, save_file=config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
             else:
-                save_videos_grid(images, args.save_video_path, fps=24)
+                save_videos_grid(images, config.save_video_path, fps=24)

     end_time = time.time()
     print(f"Total cost: {end_time - start_time}")
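
Taken together, the entry point now threads a single config object and a single inputs dict. A condensed sketch of the control flow after this commit (names taken from the diff above; argument parsing and timing elided, so this is not standalone-runnable):

    args = parser.parse_args()
    config = set_config(args)  # CLI flags + optional JSON merged into one EasyDict

    model, text_encoders, vae_model, image_encoder = load_models(config)
    image_encoder_output = run_image_encoder(config, image_encoder, vae_model)  # i2v only
    text_encoder_output = run_text_encoder(config["prompt"], text_encoders, config, image_encoder_output)

    inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
    set_target_shape(config, image_encoder_output)
    model.set_scheduler(init_scheduler(config, image_encoder_output))
    latents, generator = run_main_inference(model, inputs)  # each step calls model.infer(inputs)
    images = run_vae(latents, generator, config)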
lightx2v/text2v/models/networks/hunyuan/infer/post_infer.py

...
@@ -2,17 +2,20 @@ import torch

 class HunyuanPostInfer:
-    def __init__(self):
-        pass
+    def __init__(self, config):
+        self.config = config
+
+    def set_scheduler(self, scheduler):
+        self.scheduler = scheduler

-    def infer(self, weights, img, vec, shape):
+    def infer(self, weights, img, vec):
         out = torch.nn.functional.silu(vec)
         out = weights.final_layer_adaLN_modulation_1.apply(out)
         shift, scale = out.chunk(2, dim=1)
         out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
         out = out * (1 + scale) + shift
         out = weights.final_layer_linear.apply(out.to(torch.float32))
-        _, _, ot, oh, ow = shape
+        _, _, ot, oh, ow = self.scheduler.latents.shape
         patch_size = [1, 2, 2]
         tt, th, tw = (
             ot // patch_size[0],
...
lightx2v/text2v/models/networks/hunyuan/infer/pre_infer.py

...
@@ -5,11 +5,25 @@ from lightx2v.attentions import attention

 class HunyuanPreInfer:
-    def __init__(self):
+    def __init__(self, config):
         self.heads_num = 24
+        self.config = config
+
+    def set_scheduler(self, scheduler):
+        self.scheduler = scheduler

-    def infer(self, weights, x, t, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance, img_latents=None):
-        if img_latents is not None:
+    def infer(self, weights, inputs):
+        x = self.scheduler.latents
+        t = self.scheduler.timesteps[self.scheduler.step_index]
+        freqs_cos = self.scheduler.freqs_cos
+        freqs_sin = self.scheduler.freqs_sin
+        guidance = self.scheduler.guidance
+        text_states = inputs["text_encoder_output"]["text_encoder_1_text_states"]
+        text_mask = inputs["text_encoder_output"]["text_encoder_1_attention_mask"]
+        text_states_2 = inputs["text_encoder_output"]["text_encoder_2_text_states"]
+        if self.config["task"] == "i2v":
             token_replace_t = torch.zeros_like(t)
             token_replace_vec = self.infer_time_in(weights, token_replace_t)
             th = x.shape[-2] // 2
...
@@ -22,7 +36,7 @@ class HunyuanPreInfer:
         infer_vector_out = self.infer_vector_in(weights, text_states_2)
         vec = time_out + infer_vector_out
-        if img_latents is not None:
+        if self.config["task"] == "i2v":
             token_replace_vec = token_replace_vec + infer_vector_out
         guidance_out = self.infer_guidance_in(weights, guidance)
...
@@ -43,7 +57,7 @@ class HunyuanPreInfer:
         cu_seqlens_qkv[2 * i + 2] = s2
         max_seqlen_qkv = img_seq_len + txt_seq_len

-        if img_latents is not None:
+        if self.config["task"] == "i2v":
             return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin), token_replace_vec, frist_frame_token_num
         return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)
...
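
In short, HunyuanPreInfer no longer receives a dozen positional tensors: it reads the denoising state (latents, timestep, RoPE tables, guidance) from the scheduler injected via set_scheduler, pulls text embeddings from the shared inputs dict, and branches on config["task"] instead of checking whether img_latents was passed.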
lightx2v/text2v/models/networks/hunyuan/model.py

...
@@ -69,12 +69,14 @@ class HunyuanModel:
         self.transformer_weights.load_weights(weight_dict)

     def _init_infer(self):
-        self.pre_infer = self.pre_infer_class()
-        self.post_infer = self.post_infer_class()
+        self.pre_infer = self.pre_infer_class(self.config)
+        self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
+        self.pre_infer.set_scheduler(scheduler)
+        self.post_infer.set_scheduler(scheduler)
         self.transformer_infer.set_scheduler(scheduler)

     def to_cpu(self):
...
@@ -88,28 +90,18 @@ class HunyuanModel:
         self.transformer_weights.to_cuda()

     @torch.no_grad()
-    def infer(self, text_encoder_output, image_encoder_output, args):
+    def infer(self, inputs):
         if self.config["cpu_offload"]:
             self.pre_weight.to_cuda()
             self.post_weight.to_cuda()

-        inputs = self.pre_infer.infer(
-            self.pre_weight,
-            self.scheduler.latents,
-            self.scheduler.timesteps[self.scheduler.step_index],
-            text_encoder_output["text_encoder_1_text_states"],
-            text_encoder_output["text_encoder_1_attention_mask"],
-            text_encoder_output["text_encoder_2_text_states"],
-            self.scheduler.freqs_cos,
-            self.scheduler.freqs_sin,
-            self.scheduler.guidance,
-            img_latents=image_encoder_output["img_latents"] if "img_latents" in image_encoder_output else None,
-        )
-        inputs = self.transformer_infer.infer(self.transformer_weights, *inputs)
-        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, *inputs)
+        pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs)
+        img, vec = self.transformer_infer.infer(self.transformer_weights, *pre_infer_out)
+        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape)

         if self.config["cpu_offload"]:
             self.pre_weight.to_cpu()
             self.post_weight.to_cpu()

         if self.config["feature_caching"] == "Tea":
             self.scheduler.cnt += 1
             if self.scheduler.cnt == self.scheduler.num_steps:
...
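
The enabling pattern here is set_scheduler fanning out to every sub-infer object, so each stage reads step state locally instead of having it threaded through call sites. A minimal self-contained sketch of the wiring (Stage is a hypothetical stand-in for the pre/post/transformer infer classes):

    class Stage:
        def __init__(self, config):
            self.config = config

        def set_scheduler(self, scheduler):
            self.scheduler = scheduler  # shared, mutable per-step state

    class Model:
        def __init__(self, config):
            self.stages = [Stage(config) for _ in range(3)]  # pre, transformer, post

        def set_scheduler(self, scheduler):
            for stage in self.stages:  # one call wires every stage
                stage.set_scheduler(scheduler)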
lightx2v/text2v/models/networks/wan/infer/post_infer.py

...
@@ -8,6 +8,9 @@ class WanPostInfer:
         self.out_dim = config["out_dim"]
         self.patch_size = (1, 2, 2)

+    def set_scheduler(self, scheduler):
+        self.scheduler = scheduler
+
     def infer(self, weights, x, e, grid_sizes):
         e = (weights.head_modulation + e.unsqueeze(1)).chunk(2, dim=1)
         norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
...
lightx2v/text2v/models/networks/wan/infer/pre_infer.py

...
@@ -22,7 +22,20 @@ class WanPreInfer:
         self.dim = config["dim"]
         self.text_len = config["text_len"]

-    def infer(self, weights, x, t, context, seq_len, clip_fea=None, y=None):
+    def set_scheduler(self, scheduler):
+        self.scheduler = scheduler
+
+    def infer(self, weights, inputs, positive):
+        x = [self.scheduler.latents]
+        t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
+        if positive:
+            context = inputs["text_encoder_output"]["context"]
+        else:
+            context = inputs["text_encoder_output"]["context_null"]
+        seq_len = self.scheduler.seq_len
+        clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
+        y = [inputs["image_encoder_output"]["vae_encode_out"]]
         if self.task == "i2v":
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
...
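
The positive flag replaces the old context argument: WanModel.infer (next file) runs this pre-infer pass twice per step, once with the prompt embedding (context, positive=True) and once with the negative-prompt embedding (context_null, positive=False), feeding the two classifier-free guidance branches.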
lightx2v/text2v/models/networks/wan/model.py

...
@@ -95,6 +95,8 @@ class WanModel:
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
+        self.pre_infer.set_scheduler(scheduler)
+        self.post_infer.set_scheduler(scheduler)
         self.transformer_infer.set_scheduler(scheduler)

     def to_cpu(self):
...
@@ -108,24 +110,13 @@ class WanModel:
         self.transformer_weights.to_cuda()

     @torch.no_grad()
-    def infer(self, text_encoders_output, image_encoder_output, args):
-        timestep = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
+    def infer(self, inputs):
         if self.config["cpu_offload"]:
             self.pre_weight.to_cuda()
             self.post_weight.to_cuda()

-        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
-            self.pre_weight,
-            [self.scheduler.latents],
-            timestep,
-            text_encoders_output["context"],
-            self.scheduler.seq_len,
-            image_encoder_output["clip_encoder_out"],
-            [image_encoder_output["vae_encode_out"]],
-        )
+        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

         if self.config["feature_caching"] == "Tea":
...
@@ -133,16 +124,7 @@ class WanModel:
             if self.scheduler.cnt >= self.scheduler.num_steps:
                 self.scheduler.cnt = 0

-        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
-            self.pre_weight,
-            [self.scheduler.latents],
-            timestep,
-            text_encoders_output["context_null"],
-            self.scheduler.seq_len,
-            image_encoder_output["clip_encoder_out"],
-            [image_encoder_output["vae_encode_out"]],
-        )
+        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
...
@@ -151,7 +133,7 @@ class WanModel:
             if self.scheduler.cnt >= self.scheduler.num_steps:
                 self.scheduler.cnt = 0

-        self.scheduler.noise_pred = noise_pred_uncond + args.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
+        self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)

         if self.config["cpu_offload"]:
             self.pre_weight.to_cpu()
...
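
The final line is standard classifier-free guidance: noise_pred = uncond + s * (cond - uncond) with s = config.sample_guide_scale, so s = 1 reproduces the conditional prediction and larger s pushes the sample toward the prompt. A toy numeric check with invented values:

    import torch

    sample_guide_scale = 5.0
    noise_pred_cond = torch.tensor([1.0])
    noise_pred_uncond = torch.tensor([0.2])
    noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
    print(noise_pred)  # tensor([4.2000]), i.e. 0.2 + 5.0 * 0.8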
lightx2v/text2v/models/text_encoders/hf/clip/model.py

...
@@ -23,8 +23,8 @@ class TextEncoderHFClipModel:
         self.model = self.model.to("cuda")

     @torch.no_grad()
-    def infer(self, text, args):
-        if args.cpu_offload:
+    def infer(self, text, config):
+        if config.cpu_offload:
             self.to_cuda()

         tokens = self.tokenizer(
             text,
...
@@ -44,7 +44,7 @@ class TextEncoderHFClipModel:
         )

         last_hidden_state = outputs["pooler_output"]
-        if args.cpu_offload:
+        if config.cpu_offload:
             self.to_cpu()
         return last_hidden_state, tokens["attention_mask"]
...
lightx2v/text2v/models/text_encoders/hf/llama/model.py

...
@@ -34,8 +34,8 @@ class TextEncoderHFLlamaModel:
         self.model = self.model.to("cuda")

     @torch.no_grad()
-    def infer(self, text, args):
-        if args.cpu_offload:
+    def infer(self, text, config):
+        if config.cpu_offload:
             self.to_cuda()
         text = self.prompt_template.format(text)
         tokens = self.tokenizer(
...
@@ -57,7 +57,7 @@ class TextEncoderHFLlamaModel:
         last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)][:, self.crop_start :]
         attention_mask = tokens["attention_mask"][:, self.crop_start :]
-        if args.cpu_offload:
+        if config.cpu_offload:
             self.to_cpu()
         return last_hidden_state, attention_mask
...
lightx2v/text2v/models/text_encoders/hf/llava/model.py

...
@@ -98,9 +98,9 @@ class TextEncoderHFLlavaModel:
         self.model = self.model.to("cuda")

     @torch.no_grad()
-    def infer(self, text, img, args):
-        # if args.cpu_offload:
-        #     self.to_cuda()
+    def infer(self, text, img, config):
+        if config.cpu_offload:
+            self.to_cuda()
         text = self.prompt_template.format(text)
         print(f"text: {text}")
         tokens = self.tokenizer(
...
@@ -148,8 +148,8 @@ class TextEncoderHFLlavaModel:
         last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
         attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
-        # if args.cpu_offload:
-        #     self.to_cpu()
+        if config.cpu_offload:
+            self.to_cpu()
         return last_hidden_state, attention_mask
...
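
Note that for the LLaVA encoder this is slightly more than a rename: the cpu_offload branches were previously commented out and are now active, so i2v runs with cpu_offload enabled will start moving this encoder between CPU and GPU like the other encoders.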
lightx2v/text2v/models/text_encoders/hf/t5/model.py

...
@@ -492,8 +492,8 @@ class T5EncoderModel:
     def to_cuda(self):
         self.model = self.model.to("cuda")

-    def infer(self, texts, args):
-        if args.cpu_offload:
+    def infer(self, texts, config):
+        if config.cpu_offload:
             self.to_cuda()
         ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
...
@@ -502,7 +502,7 @@ class T5EncoderModel:
         seq_lens = mask.gt(0).sum(dim=1).long()
         context = self.model(ids, mask)
-        if args.cpu_offload:
+        if config.cpu_offload:
             self.to_cpu()
         return [u[:v] for u, v in zip(context, seq_lens)]
...
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/model.py

...
@@ -4,15 +4,15 @@ from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D, DiagonalGaussianDis

 class VideoEncoderKLCausal3DModel:
-    def __init__(self, model_path, dtype, device, args):
+    def __init__(self, model_path, dtype, device, config):
         self.model_path = model_path
         self.dtype = dtype
         self.device = device
-        self.args = args
+        self.config = config
         self.load()

     def load(self):
-        if self.args.task == "t2v":
+        if self.config.task == "t2v":
             self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
         else:
             self.vae_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/vae")
...
@@ -30,8 +30,8 @@ class VideoEncoderKLCausal3DModel:
     def to_cuda(self):
         self.model = self.model.to("cuda")

-    def decode(self, latents, generator, args):
-        if args.cpu_offload:
+    def decode(self, latents, generator, config):
+        if config.cpu_offload:
             self.to_cuda()
         latents = latents / self.model.config.scaling_factor
         latents = latents.to(dtype=self.dtype, device=torch.device("cuda"))
...
@@ -39,11 +39,11 @@ class VideoEncoderKLCausal3DModel:
         image = self.model.decode(latents, return_dict=False, generator=generator)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().float()
-        if args.cpu_offload:
+        if config.cpu_offload:
             self.to_cpu()
         return image

-    def encode(self, x, args):
+    def encode(self, x, config):
         h = self.model.encoder(x)
         moments = self.model.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
...
lightx2v/text2v/models/video_encoders/hf/wan/vae.py

...
@@ -786,8 +786,8 @@ class WanVAE:
         return images

-    def decode(self, zs, generator, args):
-        if args.cpu_offload:
+    def decode(self, zs, generator, config):
+        if config.cpu_offload:
             self.to_cuda()
         if self.parallel:
...
@@ -806,7 +806,7 @@ class WanVAE:
         else:
             images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
-        if args.cpu_offload:
+        if config.cpu_offload:
             images = images.cpu().float()
             self.to_cpu()
...
lightx2v/utils/set_config.py (new file, mode 0 → 100644)

+import json
+
+from easydict import EasyDict
+
+
+def set_config(args):
+    config = {k: v for k, v in vars(args).items()}
+    config = EasyDict(config)
+
+    if args.mm_config:
+        config.mm_config = json.loads(args.mm_config)
+    else:
+        config.mm_config = None
+
+    if args.config_path is not None:
+        with open(args.config_path, "r") as f:
+            model_config = json.load(f)
+        config.update(model_config)
+
+    return config
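
A minimal sketch of how set_config is meant to be driven; the argparse flags here are assumed from their uses in __main__.py above, not from a documented CLI:

    import argparse
    from lightx2v.utils.set_config import set_config

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cls", default="wan2.1")
    parser.add_argument("--task", default="t2v")
    parser.add_argument("--cpu_offload", action="store_true")
    parser.add_argument("--mm_config", default=None)    # JSON string, parsed by set_config
    parser.add_argument("--config_path", default=None)  # JSON file, merged last (wins on conflicts)
    args = parser.parse_args([])

    config = set_config(args)
    print(config.model_cls, config["task"])  # EasyDict allows attribute or key access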