xuwx1 / LightX2V · Commits · b50498fa

Commit b50498fa (unverified), authored Dec 02, 2025 by Yang Yong (雍洋), committed by GitHub on Dec 02, 2025

Add lightx2v_platform (#541)

Parent: 31da6925
Changes: 75
Showing 20 changed files with 90 additions and 105 deletions (+90 -105):

lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py      +6   -6
lightx2v/models/input_encoders/hf/wan/t5/model.py                  +4   -5
lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py         +5   -4
lightx2v/models/networks/hunyuan_video/infer/pre_infer.py          +2   -2
lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py  +2   -2
lightx2v/models/networks/hunyuan_video/model.py                    +4   -4
lightx2v/models/networks/qwen_image/infer/transformer_infer.py     +1   -1
lightx2v/models/networks/qwen_image/model.py                       +6   -5
lightx2v/models/networks/wan/audio_model.py                        +2   -2
lightx2v/models/networks/wan/infer/pre_infer.py                    +0   -2
lightx2v/models/networks/wan/infer/triton_ops.py                   +1   -1
lightx2v/models/networks/wan/model.py                              +4   -5
lightx2v/models/runners/base_runner.py                             +4   -2
lightx2v/models/runners/default_runner.py                          +2   -2
lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py   +16  -17
lightx2v/models/runners/qwen_image/qwen_image_runner.py            +8   -13
lightx2v/models/runners/wan/wan_audio_runner.py                    +9   -12
lightx2v/models/runners/wan/wan_runner.py                          +8   -11
lightx2v/models/schedulers/hunyuan_video/posemb_layers.py          +1   -2
lightx2v/models/schedulers/hunyuan_video/scheduler.py              +5   -7
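The diffs below all follow one pattern: per-object run_device parameters and config.get("run_device", "cuda") lookups are replaced by a single process-wide constant, AI_DEVICE, imported from the new lightx2v_platform package. The package sources themselves are not among the 20 rendered files, so the sketch below is only a guess at what lightx2v_platform/base/global_var.py might provide; the module contents and string values are assumptions, not code from this commit.

    # lightx2v_platform/base/global_var.py -- hypothetical sketch, not from the commit
    import torch

    # Backend name used everywhere in place of the old "run_device" config key,
    # e.g. "cuda", "npu", or "mlu" depending on the installed accelerator stack.
    if torch.cuda.is_available():
        AI_DEVICE = "cuda"
    else:
        AI_DEVICE = "cpu"  # a real implementation would likely probe vendor backends too

As the hunks show, the constant is consumed as torch.device(AI_DEVICE), tensor.to(AI_DEVICE), or getattr(torch, AI_DEVICE).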
lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py

@@ -2,18 +2,18 @@ import torch
 from transformers import AutoFeatureExtractor, AutoModel
 from lightx2v.utils.envs import *
+from lightx2v_platform.base.global_var import AI_DEVICE

 class SekoAudioEncoderModel:
-    def __init__(self, model_path, audio_sr, cpu_offload, run_device):
+    def __init__(self, model_path, audio_sr, cpu_offload):
         self.model_path = model_path
         self.audio_sr = audio_sr
         self.cpu_offload = cpu_offload
         if self.cpu_offload:
             self.device = torch.device("cpu")
         else:
-            self.device = torch.device(run_device)
-        self.run_device = run_device
+            self.device = torch.device(AI_DEVICE)
         self.load()

     def load(self):

@@ -27,13 +27,13 @@ class SekoAudioEncoderModel:
         self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")

     def to_cuda(self):
-        self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
+        self.audio_feature_encoder = self.audio_feature_encoder.to(AI_DEVICE)

     @torch.no_grad()
     def infer(self, audio_segment):
-        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.run_device).to(dtype=GET_DTYPE())
+        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(AI_DEVICE).to(dtype=GET_DTYPE())
         if self.cpu_offload:
-            self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
+            self.audio_feature_encoder = self.audio_feature_encoder.to(AI_DEVICE)
         audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
         if self.cpu_offload:
             self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
lightx2v/models/input_encoders/hf/wan/t5/model.py

@@ -24,8 +24,8 @@ from lightx2v.models.input_encoders.hf.q_linear import (  # noqa E402
     SglQuantLinearFp8,  # noqa E402
     TorchaoQuantLinearInt8,  # noqa E402
     VllmQuantLinearInt8,  # noqa E402,
-    MluQuantLinearInt8,
 )
+from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8  # noqa E402
 from lightx2v.models.input_encoders.hf.wan.t5.tokenizer import HuggingfaceTokenizer  # noqa E402
 from lightx2v.utils.envs import *  # noqa E402
 from lightx2v.utils.registry_factory import (  # noqa E402

@@ -34,6 +34,7 @@ from lightx2v.utils.registry_factory import (  # noqa E402
     RMS_WEIGHT_REGISTER,  # noqa E402
 )
 from lightx2v.utils.utils import load_weights  # noqa E402
+from lightx2v_platform.base.global_var import AI_DEVICE  # noqa E402

 __all__ = [
     "T5Model",

@@ -745,7 +746,6 @@ class T5EncoderModel:
         text_len,
         dtype=torch.bfloat16,
         device=torch.device("cuda"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         tokenizer_path=None,
         shard_fn=None,

@@ -758,7 +758,6 @@ class T5EncoderModel:
         self.text_len = text_len
         self.dtype = dtype
         self.device = device
-        self.run_device = run_device
         if t5_quantized_ckpt is not None and t5_quantized:
             self.checkpoint_path = t5_quantized_ckpt
         else:

@@ -807,8 +806,8 @@ class T5EncoderModel:
     def infer(self, texts):
         ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
-        ids = ids.to(self.run_device)
-        mask = mask.to(self.run_device)
+        ids = ids.to(AI_DEVICE)
+        mask = mask.to(AI_DEVICE)
         seq_lens = mask.gt(0).sum(dim=1).long()
         with torch.no_grad():
lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py

@@ -10,8 +10,10 @@ from loguru import logger
 # from lightx2v.attentions import attention
 from lightx2v.common.ops.attn import TorchSDPAWeight
-from lightx2v.models.input_encoders.hf.q_linear import MluQuantLinearInt8, Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
+from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
 from lightx2v.utils.utils import load_weights
+from lightx2v_platform.base.global_var import AI_DEVICE
+from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8

 __all__ = [
     "XLMRobertaCLIP",

@@ -426,9 +428,8 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False, run_device=torch.device("cuda")):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False):
         self.dtype = dtype
-        self.run_device = run_device
         self.quantized = clip_quantized
         self.cpu_offload = cpu_offload
         self.use_31_block = use_31_block

@@ -462,7 +463,7 @@ class CLIPModel:
         return out

     def to_cuda(self):
-        self.model = self.model.to(self.run_device)
+        self.model = self.model.to(AI_DEVICE)

     def to_cpu(self):
         self.model = self.model.cpu()
lightx2v/models/networks/hunyuan_video/infer/pre_infer.py

@@ -5,6 +5,7 @@ import torch
 from einops import rearrange
 from lightx2v.utils.envs import *
+from lightx2v_platform.base.global_var import AI_DEVICE
 from .attn_no_pad import flash_attn_no_pad, flash_attn_no_pad_v3, sage_attn_no_pad_v2
 from .module_io import HunyuanVideo15InferModuleOutput

@@ -68,7 +69,6 @@ class HunyuanVideo15PreInfer:
         self.heads_num = config["heads_num"]
         self.frequency_embedding_size = 256
         self.max_period = 10000
-        self.run_device = torch.device(self.config.get("run_device", "cuda"))

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

@@ -155,7 +155,7 @@ class HunyuanVideo15PreInfer:
             byt5_txt = byt5_txt + weights.cond_type_embedding.apply(torch.ones_like(byt5_txt[:, :, 0], device=byt5_txt.device, dtype=torch.long))
             txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=True)
-        siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=self.run_device))
+        siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=AI_DEVICE))
         txt, text_mask = self.reorder_txt_token(siglip_output, txt, siglip_mask, text_mask)
         txt = txt[:, : text_mask.sum(), :]
lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py

@@ -10,6 +10,7 @@ except Exception as e:
     apply_rope_with_cos_sin_cache_inplace = None
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE
 from .module_io import HunyuanVideo15ImgBranchOutput, HunyuanVideo15TxtBranchOutput
 from .triton_ops import fuse_scale_shift_kernel

@@ -100,7 +101,6 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
         self.config = config
         self.double_blocks_num = config["mm_double_blocks_depth"]
         self.heads_num = config["heads_num"]
-        self.run_device = torch.device(self.config.get("run_device", "cuda"))
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
             self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)

@@ -222,7 +222,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
         key = torch.cat([img_k, txt_k], dim=1)
         value = torch.cat([img_v, txt_v], dim=1)
         seqlen = query.shape[1]
-        cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(self.run_device, non_blocking=True)
+        cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True)
         if self.config["seq_parallel"]:
             attn_out = weights.self_attention_parallel.apply(
lightx2v/models/networks/hunyuan_video/model.py

@@ -176,12 +176,12 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if self.device.type != "cpu" and dist.is_initialized():
-            device = torch.device("{}:{}".format(self.device.type, dist.get_rank()))
+        if self.config["parallel"]:
+            device = dist.get_rank()
         else:
-            device = self.device
-        with safe_open(file_path, framework="pt", device=str(device)) as f:
+            device = str(self.device)
+        with safe_open(file_path, framework="pt", device=device) as f:
             return {
                 key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
                 for key in f.keys()
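The weight-loading change above (repeated below in qwen_image/model.py and wan/model.py) switches the branch condition from "non-CPU device and torch.distributed initialized" to the "parallel" config flag, and passes either the bare rank integer or the stringified device to safe_open. A condensed sketch of that pattern, with names trimmed for illustration and not taken from the repo, is:

    # Illustrative sketch of the loader pattern above; not the repo's code.
    import torch.distributed as dist
    from safetensors import safe_open

    def load_shard(file_path, parallel, device):
        # Under parallel execution each rank opens the file on its own device index;
        # this relies on safe_open accepting an integer index for framework="pt",
        # otherwise the configured device is passed as a string.
        target = dist.get_rank() if parallel else str(device)
        with safe_open(file_path, framework="pt", device=target) as f:
            return {key: f.get_tensor(key) for key in f.keys()}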
lightx2v/models/networks/qwen_image/infer/transformer_infer.py

@@ -111,7 +111,7 @@ def apply_attn(block_weight, hidden_states, encoder_hidden_states, image_rotary_
     if attn_type == "torch_sdpa":
         joint_hidden_states = block_weight.attn.calculate.apply(q=joint_query, k=joint_key, v=joint_value)
-    elif attn_type in ["flash_attn3", "sage_attn2", "mlu_flash_attn", "flash_attn2", "mlu_sage_attn"]:
+    else:
         joint_query = joint_query.squeeze(0)
         joint_key = joint_key.squeeze(0)
         joint_value = joint_value.squeeze(0)
lightx2v/models/networks/qwen_image/model.py

@@ -8,6 +8,7 @@ from safetensors import safe_open
 from lightx2v.utils.envs import *
 from lightx2v.utils.utils import *
+from lightx2v_platform.base.global_var import AI_DEVICE
 from .infer.offload.transformer_infer import QwenImageOffloadTransformerInfer
 from .infer.post_infer import QwenImagePostInfer

@@ -28,7 +29,7 @@ class QwenImageTransformerModel:
         self.model_path = os.path.join(config["model_path"], "transformer")
         self.cpu_offload = config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
-        self.device = torch.device("cpu") if self.cpu_offload else torch.device(self.config.get("run_device", "cuda"))
+        self.device = torch.device("cpu") if self.cpu_offload else torch.device(AI_DEVICE)
         with open(os.path.join(config["model_path"], "transformer", "config.json"), "r") as f:
             transformer_config = json.load(f)

@@ -124,12 +125,12 @@ class QwenImageTransformerModel:
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if self.device.type in ["cuda", "mlu", "npu"] and dist.is_initialized():
-            device = torch.device("{}:{}".format(self.device.type, dist.get_rank()))
+        if self.config["parallel"]:
+            device = dist.get_rank()
         else:
-            device = self.device
-        with safe_open(file_path, framework="pt", device=str(device)) as f:
+            device = str(self.device)
+        with safe_open(file_path, framework="pt", device=device) as f:
             return {
                 key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
                 for key in f.keys()
lightx2v/models/networks/wan/audio_model.py

@@ -13,6 +13,7 @@ from lightx2v.models.networks.wan.weights.audio.transformer_weights import WanAu
 from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
 from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
 from lightx2v.utils.utils import load_weights
+from lightx2v_platform.base.global_var import AI_DEVICE

 class WanAudioModel(WanModel):

@@ -22,7 +23,6 @@ class WanAudioModel(WanModel):
     def __init__(self, model_path, config, device):
         self.config = config
-        self.run_device = self.config.get("run_device", "cuda")
         self._load_adapter_ckpt()
         super().__init__(model_path, config, device)

@@ -51,7 +51,7 @@ class WanAudioModel(WanModel):
         if not adapter_offload:
             if not dist.is_initialized() or not load_from_rank0:
                 for key in self.adapter_weights_dict:
-                    self.adapter_weights_dict[key] = self.adapter_weights_dict[key].to(torch.device(self.run_device))
+                    self.adapter_weights_dict[key] = self.adapter_weights_dict[key].to(torch.device(AI_DEVICE))

     def _init_infer_class(self):
         super()._init_infer_class()
lightx2v/models/networks/wan/infer/pre_infer.py

@@ -10,10 +10,8 @@ class WanPreInfer:
     def __init__(self, config):
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
         self.config = config
-        self.run_device = self.config.get("run_device", "cuda")
         self.clean_cuda_cache = config.get("clean_cuda_cache", False)
         self.task = config["task"]
-        self.device = torch.device(self.config.get("run_device", "cuda"))
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
         self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
lightx2v/models/networks/wan/infer/triton_ops.py

@@ -124,7 +124,7 @@ def fuse_scale_shift_kernel(
     block_l: int = 128,
     block_c: int = 128,
 ):
-    assert x.is_cuda and scale.is_cuda
+    # assert x.is_cuda and scale.is_cuda
     assert x.is_contiguous()
     if x.dim() == 2:
         x = x.unsqueeze(0)
lightx2v/models/networks/wan/model.py

@@ -44,7 +44,6 @@ class WanModel(CompiledMethodsMixin):
         super().__init__()
         self.model_path = model_path
         self.config = config
-        self.device = self.config.get("run_device", "cuda")
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
         self.model_type = model_type

@@ -147,12 +146,12 @@ class WanModel(CompiledMethodsMixin):
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if (self.device.type == "cuda" or self.device.type == "mlu") and dist.is_initialized():
-            device = torch.device("{}:{}".format(self.device.type, dist.get_rank()))
+        if self.config["parallel"]:
+            device = dist.get_rank()
         else:
-            device = self.device
-        with safe_open(file_path, framework="pt", device=str(device)) as f:
+            device = str(self.device)
+        with safe_open(file_path, framework="pt", device=device) as f:
             return {
                 key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
                 for key in f.keys()
lightx2v/models/runners/base_runner.py

@@ -3,6 +3,8 @@ from abc import ABC
 import torch
 import torch.distributed as dist
+from lightx2v_platform.base.global_var import AI_DEVICE
+
 class BaseRunner(ABC):
     """Abstract base class for all Runners

@@ -145,9 +147,9 @@ class BaseRunner(ABC):
         if world_size > 1:
             if rank == signal_rank:
-                t = torch.tensor([stopped], dtype=torch.int32).to(device=self.config.get("run_device", "cuda"))
+                t = torch.tensor([stopped], dtype=torch.int32).to(device=AI_DEVICE)
             else:
-                t = torch.zeros(1, dtype=torch.int32, device=self.config.get("run_device", "cuda"))
+                t = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
             dist.broadcast(t, src=signal_rank)
             stopped = t.item()
lightx2v/models/runners/default_runner.py

@@ -15,6 +15,7 @@ from lightx2v.utils.global_paras import CALIB
 from lightx2v.utils.memory_profiler import peak_memory_decorator
 from lightx2v.utils.profiler import *
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
+from lightx2v_platform.base.global_var import AI_DEVICE
 from .base_runner import BaseRunner

@@ -59,11 +60,10 @@ class DefaultRunner(BaseRunner):
             self.model.compile(self.config.get("compile_shapes", []))

     def set_init_device(self):
-        self.run_device = self.config.get("run_device", "cuda")
         if self.config["cpu_offload"]:
             self.init_device = torch.device("cpu")
         else:
-            self.init_device = torch.device(self.config.get("run_device", "cuda"))
+            self.init_device = torch.device(AI_DEVICE)

     def load_vfi_model(self):
         if self.config["video_frame_interpolation"].get("algo", None) == "rife":
lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py

@@ -21,6 +21,7 @@ from lightx2v.server.metrics import monitor_cli
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import *
+from lightx2v_platform.base.global_var import AI_DEVICE

 @RUNNER_REGISTER("hunyuan_video_1.5")

@@ -71,7 +72,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if qwen25vl_offload:
             qwen25vl_device = torch.device("cpu")
         else:
-            qwen25vl_device = torch.device(self.run_device)
+            qwen25vl_device = torch.device(AI_DEVICE)
         qwen25vl_quantized = self.config.get("qwen25vl_quantized", False)
         qwen25vl_quant_scheme = self.config.get("qwen25vl_quant_scheme", None)

@@ -82,7 +83,6 @@ class HunyuanVideo15Runner(DefaultRunner):
         text_encoder = Qwen25VL_TextEncoder(
             dtype=torch.float16,
             device=qwen25vl_device,
-            run_device=self.run_device,
             checkpoint_path=text_encoder_path,
             cpu_offload=qwen25vl_offload,
             qwen25vl_quantized=qwen25vl_quantized,

@@ -94,9 +94,9 @@ class HunyuanVideo15Runner(DefaultRunner):
         if byt5_offload:
             byt5_device = torch.device("cpu")
         else:
-            byt5_device = torch.device(self.run_device)
-        byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, run_device=self.run_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
+            byt5_device = torch.device(AI_DEVICE)
+        byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
         text_encoders = [text_encoder, byt5]
         return text_encoders

@@ -230,11 +230,10 @@ class HunyuanVideo15Runner(DefaultRunner):
         if siglip_offload:
             siglip_device = torch.device("cpu")
         else:
-            siglip_device = torch.device(self.run_device)
+            siglip_device = torch.device(AI_DEVICE)
         image_encoder = SiglipVisionEncoder(
             config=self.config,
             device=siglip_device,
-            run_device=self.run_device,
             checkpoint_path=self.config["model_path"],
             cpu_offload=siglip_offload,
         )

@@ -246,7 +245,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.run_device)
+            vae_device = torch.device(AI_DEVICE)
         vae_config = {
             "checkpoint_path": self.config["model_path"],

@@ -265,7 +264,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.run_device)
+            vae_device = torch.device(AI_DEVICE)
         vae_config = {
             "checkpoint_path": self.config["model_path"],

@@ -275,7 +274,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         }
         if self.config.get("use_tae", False):
             tae_path = self.config["tae_path"]
-            vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(self.run_device)
+            vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(AI_DEVICE)
         else:
             vae_decoder = self.vae_cls(**vae_config)
         return vae_decoder

@@ -350,7 +349,7 @@ class HunyuanVideo15Runner(DefaultRunner):
             self.model_sr.scheduler.step_post()
         del self.inputs_sr
-        torch_ext_module = getattr(torch, self.run_device)
+        torch_ext_module = getattr(torch, AI_DEVICE)
         torch_ext_module.empty_cache()
         self.config_sr["is_sr_running"] = False

@@ -369,10 +368,10 @@ class HunyuanVideo15Runner(DefaultRunner):
         text_encoder_output = self.run_text_encoder(self.input_info)
         # vision_states is all zero, because we don't have any image input
-        siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(self.run_device)
-        siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(self.run_device))
-        torch_ext_module = getattr(torch, self.run_device)
+        siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(AI_DEVICE)
+        siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(AI_DEVICE))
+        torch_ext_module = getattr(torch, AI_DEVICE)
         torch_ext_module.empty_cache()
         gc.collect()
         return {

@@ -400,7 +399,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         siglip_output, siglip_mask = self.run_image_encoder(img_ori) if self.config.get("use_image_encoder", True) else None
         cond_latents = self.run_vae_encoder(img_ori)
         text_encoder_output = self.run_text_encoder(self.input_info)
-        torch_ext_module = getattr(torch, self.run_device)
+        torch_ext_module = getattr(torch, AI_DEVICE)
         torch_ext_module.empty_cache()
         gc.collect()
         return {

@@ -427,9 +426,9 @@ class HunyuanVideo15Runner(DefaultRunner):
         target_height = self.target_height
         input_image_np = self.resize_and_center_crop(first_frame, target_width=target_width, target_height=target_height)
-        vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(self.run_device), dtype=torch.bfloat16)
+        vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(AI_DEVICE), dtype=torch.bfloat16)
         image_encoder_output = self.image_encoder.infer(vision_states)
-        image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(self.run_device))
+        image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(AI_DEVICE))
         return image_encoder_output, image_encoder_mask

     def resize_and_center_crop(self, image, target_width, target_height):

@@ -480,6 +479,6 @@ class HunyuanVideo15Runner(DefaultRunner):
             ]
         )
-        ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(self.run_device)
+        ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(AI_DEVICE)
         cond_latents = self.vae_encoder.encode(ref_images_pixel_values.to(GET_DTYPE()))
         return cond_latents
lightx2v/models/runners/qwen_image/qwen_image_runner.py

@@ -15,6 +15,9 @@ from lightx2v.server.metrics import monitor_cli
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)

 def calculate_dimensions(target_area, ratio):

@@ -85,9 +88,7 @@ class QwenImageRunner(DefaultRunner):
     def _run_input_encoder_local_t2i(self):
         prompt = self.input_info.prompt
         text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
-        if hasattr(torch, self.run_device):
-            torch_module = getattr(torch, self.run_device)
-            torch_module.empty_cache()
+        torch_device_module.empty_cache()
         gc.collect()
         return {
             "text_encoder_output": text_encoder_output,

@@ -102,7 +103,7 @@ class QwenImageRunner(DefaultRunner):
         if GET_RECORDER_MODE():
             width, height = img_ori.size
             monitor_cli.lightx2v_input_image_len.observe(width * height)
-        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.run_device)
+        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
         self.input_info.original_size.append(img_ori.size)
         return img, img_ori

@@ -121,9 +122,7 @@ class QwenImageRunner(DefaultRunner):
         for vae_image in text_encoder_output["image_info"]["vae_image_list"]:
             image_encoder_output = self.run_vae_encoder(image=vae_image)
             image_encoder_output_list.append(image_encoder_output)
-        if hasattr(torch, self.run_device):
-            torch_module = getattr(torch, self.run_device)
-            torch_module.empty_cache()
+        torch_device_module.empty_cache()
         gc.collect()
         return {
             "text_encoder_output": text_encoder_output,

@@ -238,9 +237,7 @@ class QwenImageRunner(DefaultRunner):
         images = self.vae.decode(latents, self.input_info)
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_decoder
-        if hasattr(torch, self.run_device):
-            torch_module = getattr(torch, self.run_device)
-            torch_module.empty_cache()
+        torch_device_module.empty_cache()
         gc.collect()
         return images

@@ -259,9 +256,7 @@ class QwenImageRunner(DefaultRunner):
         image.save(f"{input_info.save_result_path}")
         del latents, generator
-        if hasattr(torch, self.run_device):
-            torch_module = getattr(torch, self.run_device)
-            torch_module.empty_cache()
+        torch_device_module.empty_cache()
         gc.collect()
         # Return (images, audio) - audio is None for default runner
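The runner above now clears the accelerator cache through a single module-level torch_device_module = getattr(torch, AI_DEVICE) instead of re-checking hasattr(torch, self.run_device) at every call site. A minimal sketch of that pattern, assuming AI_DEVICE is a backend name (the "cuda" value here is only for illustration; "npu"/"mlu" builds expose the same submodule through their vendor extensions):

    # Minimal sketch of the cache-clearing pattern used above.
    import torch

    AI_DEVICE = "cuda"  # assumption for this sketch; normally imported from lightx2v_platform

    torch_device_module = getattr(torch, AI_DEVICE)  # -> torch.cuda on NVIDIA builds
    torch_device_module.empty_cache()                # same effect as torch.cuda.empty_cache()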
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -33,6 +33,7 @@ from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import find_torch_model_path, load_weights, vae_to_comfyui_image_inplace
+from lightx2v_platform.base.global_var import AI_DEVICE

 warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")
 warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io")

@@ -450,7 +451,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             ref_img = img_path
         else:
             ref_img = load_image(img_path)
-        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(self.run_device)
+        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
         ref_img, h, w = resize_image(
             ref_img,

@@ -538,15 +539,14 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
         """Prepare previous latents for conditioning"""
-        device = self.run_device
         dtype = GET_DTYPE()
         tgt_h, tgt_w = self.input_info.target_shape[0], self.input_info.target_shape[1]
-        prev_frames = torch.zeros((1, 3, self.config["target_video_length"], tgt_h, tgt_w), device=device)
+        prev_frames = torch.zeros((1, 3, self.config["target_video_length"], tgt_h, tgt_w), device=AI_DEVICE)
         if prev_video is not None:
             # Extract and process last frames
-            last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
+            last_frames = prev_video[:, :, -prev_frame_length:].clone().to(AI_DEVICE)
             if self.config["model_cls"] != "wan2.2_audio":
                 last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
             prev_frames[:, :, :prev_frame_length] = last_frames

@@ -574,7 +574,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
         frames_n = (nframe - 1) * 4 + 1
-        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+        prev_mask = torch.ones((1, frames_n, height, width), device=AI_DEVICE, dtype=dtype)
         prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
         prev_mask[:, prev_frame_len:] = 0
         prev_mask = self._wan_mask_rearrange(prev_mask)

@@ -835,7 +835,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_audio_encoder(self):
         audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"))
         audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
-        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, run_device=self.config.get("run_device", "cuda"))
+        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload)
         return model

     def load_audio_adapter(self):

@@ -843,7 +843,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if audio_adapter_offload:
             device = torch.device("cpu")
         else:
-            device = torch.device(self.run_device)
+            device = torch.device(AI_DEVICE)
         audio_adapter = AudioAdapter(
             attention_head_dim=self.config["dim"] // self.config["num_heads"],
             num_attention_heads=self.config["num_heads"],

@@ -856,7 +856,6 @@ class WanAudioRunner(WanRunner):  # type:ignore
             quantized=self.config.get("adapter_quantized", False),
             quant_scheme=self.config.get("adapter_quant_scheme", None),
             cpu_offload=audio_adapter_offload,
-            run_device=self.run_device,
         )
         audio_adapter.to(device)

@@ -892,11 +891,10 @@ class Wan22AudioRunner(WanAudioRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.run_device)
+            vae_device = torch.device(AI_DEVICE)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
-            "run_device": self.run_device,
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
         }

@@ -909,11 +907,10 @@ class Wan22AudioRunner(WanAudioRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.run_device)
+            vae_device = torch.device(AI_DEVICE)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
-            "run_device": self.run_device,
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
         }
lightx2v/models/runners/wan/wan_runner.py
View file @
b50498fa
...
@@ -29,6 +29,7 @@ from lightx2v.utils.envs import *
...
@@ -29,6 +29,7 @@ from lightx2v.utils.envs import *
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.utils
import
*
from
lightx2v.utils.utils
import
*
from
lightx2v_platform.base.global_var
import
AI_DEVICE
@
RUNNER_REGISTER
(
"wan2.1"
)
@
RUNNER_REGISTER
(
"wan2.1"
)
...
@@ -65,7 +66,7 @@ class WanRunner(DefaultRunner):
...
@@ -65,7 +66,7 @@ class WanRunner(DefaultRunner):
if
clip_offload
:
if
clip_offload
:
clip_device
=
torch
.
device
(
"cpu"
)
clip_device
=
torch
.
device
(
"cpu"
)
else
:
else
:
clip_device
=
torch
.
device
(
self
.
run_device
)
clip_device
=
torch
.
device
(
AI_DEVICE
)
# quant_config
# quant_config
clip_quantized
=
self
.
config
.
get
(
"clip_quantized"
,
False
)
clip_quantized
=
self
.
config
.
get
(
"clip_quantized"
,
False
)
if
clip_quantized
:
if
clip_quantized
:
...
@@ -84,7 +85,6 @@ class WanRunner(DefaultRunner):
...
@@ -84,7 +85,6 @@ class WanRunner(DefaultRunner):
image_encoder
=
CLIPModel
(
image_encoder
=
CLIPModel
(
dtype
=
torch
.
float16
,
dtype
=
torch
.
float16
,
device
=
clip_device
,
device
=
clip_device
,
run_device
=
self
.
run_device
,
checkpoint_path
=
clip_original_ckpt
,
checkpoint_path
=
clip_original_ckpt
,
clip_quantized
=
clip_quantized
,
clip_quantized
=
clip_quantized
,
            clip_quantized_ckpt=clip_quantized_ckpt,
...
@@ -102,7 +102,7 @@ class WanRunner(DefaultRunner):
         if t5_offload:
             t5_device = torch.device("cpu")
         else:
-            t5_device = torch.device(self.run_device)
+            t5_device = torch.device(AI_DEVICE)
         tokenizer_path = os.path.join(self.config["model_path"], "google/umt5-xxl")
         # quant_config
         t5_quantized = self.config.get("t5_quantized", False)
...
@@ -123,7 +123,6 @@ class WanRunner(DefaultRunner):
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
             device=t5_device,
-            run_device=self.run_device,
             checkpoint_path=t5_original_ckpt,
             tokenizer_path=tokenizer_path,
             shard_fn=None,
...
@@ -142,12 +141,11 @@ class WanRunner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.run_device)
+            vae_device = torch.device(AI_DEVICE)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
             "device": vae_device,
-            "run_device": self.run_device,
             "parallel": self.config["parallel"],
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
...
@@ -171,7 +169,6 @@ class WanRunner(DefaultRunner):
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
             "device": vae_device,
-            "run_device": self.run_device,
             "parallel": self.config["parallel"],
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
...
@@ -321,7 +318,7 @@ class WanRunner(DefaultRunner):
             self.config["target_video_length"],
             lat_h,
             lat_w,
-            device=torch.device(self.run_device),
+            device=torch.device(AI_DEVICE),
         )
         if last_frame is not None:
             msk[:, 1:-1] = 0
...
@@ -343,7 +340,7 @@ class WanRunner(DefaultRunner):
                     torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                 ],
                 dim=1,
-            ).to(self.run_device)
+            ).to(AI_DEVICE)
         else:
             vae_input = torch.concat(
                 [
...
@@ -351,7 +348,7 @@ class WanRunner(DefaultRunner):
                     torch.zeros(3, self.config["target_video_length"] - 1, h, w),
                 ],
                 dim=1,
-            ).to(self.run_device)
+            ).to(AI_DEVICE)
         vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()))
...
@@ -534,7 +531,7 @@ class Wan22DenseRunner(WanRunner):
         assert img.width == ow and img.height == oh
         # to tensor
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.run_device).unsqueeze(1)
+        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(AI_DEVICE).unsqueeze(1)
         vae_encoder_out = self.get_vae_encoder_output(img)
         latent_w, latent_h = ow // self.config["vae_stride"][2], oh // self.config["vae_stride"][1]
         latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
...
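Across these hunks the pattern is the same: the per-runner self.run_device attribute (and the run_device entries threaded through encoder and VAE configs) is dropped in favor of the platform-level AI_DEVICE constant imported from lightx2v_platform.base.global_var. The actual definition of AI_DEVICE is not part of this file's diff; the sketch below is only an illustration, under the assumption that it resolves to a single process-wide device identifier.

# Hypothetical sketch (not taken from this commit): one way a platform module could
# expose a process-wide AI_DEVICE constant that the runner code above consumes.
import torch

# Assumption: AI_DEVICE is resolved once at import time to a plain device string.
AI_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def to_run_device(tensor: torch.Tensor) -> torch.Tensor:
    """Move a tensor to the shared platform device, mirroring `.to(AI_DEVICE)` in the diff."""
    return tensor.to(torch.device(AI_DEVICE))

if __name__ == "__main__":
    x = torch.zeros(3, 4)
    print(to_run_device(x).device)  # cuda:0 on a GPU machine, otherwise cpu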
lightx2v/models/schedulers/hunyuan_video/posemb_layers.py
View file @ b50498fa
...
@@ -271,8 +271,7 @@ def get_1d_rotary_pos_embed(
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
     # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
-    run_device = kwds.get("run_device", "cuda")
-    freqs = torch.outer(pos * interpolation_factor, freqs).to(run_device)  # [S, D/2]
+    freqs = torch.outer(pos * interpolation_factor, freqs).to(AI_DEVICE)  # [S, D/2]
     if use_real:
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
...
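For reference, the rotary-embedding math touched by this hunk can be exercised in isolation. The snippet below is a standalone approximation of the real-valued branch shown above, not the library function itself; AI_DEVICE is assumed to be a valid torch device string and is pinned to "cpu" so the example runs anywhere.

# Standalone sketch of the 1D RoPE frequency computation from the hunk above.
import torch

AI_DEVICE = "cpu"  # assumption: stands in for the platform device constant

def rope_1d(pos: torch.Tensor, dim: int, theta: float = 10000.0, interpolation_factor: float = 1.0):
    # [D/2] base frequencies
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # [S, D/2] outer product of positions and frequencies, placed on the shared device
    freqs = torch.outer(pos * interpolation_factor, freqs).to(AI_DEVICE)
    freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
    freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
    return freqs_cos, freqs_sin

cos, sin = rope_1d(torch.arange(16).float(), dim=64)
print(cos.shape, sin.shape)  # torch.Size([16, 64]) torch.Size([16, 64])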
lightx2v/models/schedulers/hunyuan_video/scheduler.py
View file @ b50498fa
...
@@ -11,7 +11,6 @@ from .posemb_layers import get_nd_rotary_pos_embed
 class HunyuanVideo15Scheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.run_device = torch.device(self.config.get("run_device", "cuda"))
         self.reverse = True
         self.num_train_timesteps = 1000
         self.sample_shift = self.config["sample_shift"]
...
@@ -25,13 +24,13 @@ class HunyuanVideo15Scheduler(BaseScheduler):
     def prepare(self, seed, latent_shape, image_encoder_output=None):
         self.prepare_latents(seed, latent_shape, dtype=torch.bfloat16)
-        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
+        self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift)
         self.multitask_mask = self.get_task_mask(self.config["task"], latent_shape[-3])
         self.cond_latents_concat, self.mask_concat = self._prepare_cond_latents_and_mask(self.config["task"], image_encoder_output["cond_latents"], self.latents, self.multitask_mask, self.reorg_token)
         self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))

     def prepare_latents(self, seed, latent_shape, dtype=torch.bfloat16):
-        self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
+        self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed)
         self.latents = torch.randn(
             1,
             latent_shape[0],
...
@@ -39,7 +38,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
             latent_shape[2],
             latent_shape[3],
             dtype=dtype,
-            device=self.run_device,
+            device=AI_DEVICE,
             generator=self.generator,
         )
...
@@ -127,7 +126,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
         if rope_dim_list is None:
             rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
         assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
-        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=self.run_device)
+        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=AI_DEVICE)
         cos_half = freqs_cos[:, ::2].contiguous()
         sin_half = freqs_sin[:, ::2].contiguous()
         cos_sin = torch.cat([cos_half, sin_half], dim=-1)
...
@@ -149,9 +148,8 @@ class HunyuanVideo15SRScheduler(HunyuanVideo15Scheduler):
     def prepare(self, seed, latent_shape, lq_latents, upsampler, image_encoder_output=None):
         dtype = lq_latents.dtype
-        device = lq_latents.device
         self.prepare_latents(seed, latent_shape, lq_latents, dtype=dtype)
-        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
+        self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift)
         self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))
         tgt_shape = latent_shape[-2:]
...
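The scheduler hunks all follow the same seeded-initialization pattern: create a torch.Generator on the shared device, seed it, and draw the initial latents on that device. The sketch below is a minimal, self-contained rendering of that pattern, not the scheduler class itself; AI_DEVICE is assumed to be a valid device string and the latent shape is purely illustrative.

# Minimal sketch of the seeded latent-initialization pattern used above.
import torch

AI_DEVICE = "cpu"  # assumption: stands in for the platform device constant

def init_latents(seed: int, latent_shape, dtype=torch.bfloat16) -> torch.Tensor:
    # Seed a generator on the target device so sampling is reproducible per device.
    generator = torch.Generator(device=AI_DEVICE).manual_seed(seed)
    return torch.randn(1, *latent_shape, dtype=dtype, device=AI_DEVICE, generator=generator)

latents = init_latents(seed=42, latent_shape=(16, 9, 60, 104))
print(latents.shape)  # torch.Size([1, 16, 9, 60, 104])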