xuwx1 / LightX2V · Commits · b50498fa

Commit b50498fa (Unverified)
Authored Dec 02, 2025 by Yang Yong (雍洋); committed by GitHub on Dec 02, 2025

Add lightx2v_platform (#541)
parent 31da6925
Changes: 75 files
Showing 20 changed files with 90 additions and 105 deletions (+90, -105)
lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py  +6 -6
lightx2v/models/input_encoders/hf/wan/t5/model.py  +4 -5
lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py  +5 -4
lightx2v/models/networks/hunyuan_video/infer/pre_infer.py  +2 -2
lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py  +2 -2
lightx2v/models/networks/hunyuan_video/model.py  +4 -4
lightx2v/models/networks/qwen_image/infer/transformer_infer.py  +1 -1
lightx2v/models/networks/qwen_image/model.py  +6 -5
lightx2v/models/networks/wan/audio_model.py  +2 -2
lightx2v/models/networks/wan/infer/pre_infer.py  +0 -2
lightx2v/models/networks/wan/infer/triton_ops.py  +1 -1
lightx2v/models/networks/wan/model.py  +4 -5
lightx2v/models/runners/base_runner.py  +4 -2
lightx2v/models/runners/default_runner.py  +2 -2
lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py  +16 -17
lightx2v/models/runners/qwen_image/qwen_image_runner.py  +8 -13
lightx2v/models/runners/wan/wan_audio_runner.py  +9 -12
lightx2v/models/runners/wan/wan_runner.py  +8 -11
lightx2v/models/schedulers/hunyuan_video/posemb_layers.py  +1 -2
lightx2v/models/schedulers/hunyuan_video/scheduler.py  +5 -7
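The change repeated across the files below is mechanical: instead of threading a run_device value through constructors and reading config.get("run_device", "cuda") at each call site, modules import a process-wide device constant from the new lightx2v_platform package. A minimal sketch of the pattern, assuming AI_DEVICE is a backend name string such as "cuda", "mlu", or "npu" exported by lightx2v_platform.base.global_var (its definition lives in one of the 55 changed files not shown on this page):

# Sketch only; the exact definition of AI_DEVICE is an assumption, not shown in this excerpt.
import torch
from lightx2v_platform.base.global_var import AI_DEVICE

# Before: run_device = config.get("run_device", "cuda"); x = x.to(torch.device(run_device))
# After:  the platform-wide constant is used directly.
x = torch.zeros(4)
x = x.to(torch.device(AI_DEVICE))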
lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py

@@ -2,18 +2,18 @@ import torch
 from transformers import AutoFeatureExtractor, AutoModel

 from lightx2v.utils.envs import *
+from lightx2v_platform.base.global_var import AI_DEVICE


 class SekoAudioEncoderModel:
-    def __init__(self, model_path, audio_sr, cpu_offload, run_device):
+    def __init__(self, model_path, audio_sr, cpu_offload):
         self.model_path = model_path
         self.audio_sr = audio_sr
         self.cpu_offload = cpu_offload
         if self.cpu_offload:
             self.device = torch.device("cpu")
         else:
-            self.device = torch.device(run_device)
-        self.run_device = run_device
+            self.device = torch.device(AI_DEVICE)
         self.load()

     def load(self):
@@ -27,13 +27,13 @@ class SekoAudioEncoderModel:
             self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")

     def to_cuda(self):
-        self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
+        self.audio_feature_encoder = self.audio_feature_encoder.to(AI_DEVICE)

     @torch.no_grad()
     def infer(self, audio_segment):
-        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.run_device).to(dtype=GET_DTYPE())
+        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(AI_DEVICE).to(dtype=GET_DTYPE())
         if self.cpu_offload:
-            self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
+            self.audio_feature_encoder = self.audio_feature_encoder.to(AI_DEVICE)
         audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
         if self.cpu_offload:
             self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
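With run_device dropped from the constructor, callers create the audio encoder with three arguments and the device comes from the platform global. A hypothetical construction sketch (the path and sample rate below are placeholders, not values from this commit):

from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel

encoder = SekoAudioEncoderModel(
    model_path="/models/audio_encoder",  # placeholder path, assumption
    audio_sr=16000,                      # placeholder sample rate, assumption
    cpu_offload=False,
)
# With cpu_offload=False, self.device becomes torch.device(AI_DEVICE);
# infer() now moves extracted features to AI_DEVICE before encoding.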
lightx2v/models/input_encoders/hf/wan/t5/model.py

@@ -24,8 +24,8 @@ from lightx2v.models.input_encoders.hf.q_linear import ( # noqa E402
     SglQuantLinearFp8,  # noqa E402
     TorchaoQuantLinearInt8,  # noqa E402
     VllmQuantLinearInt8,  # noqa E402
-    MluQuantLinearInt8,
 )
+from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8  # noqa E402
 from lightx2v.models.input_encoders.hf.wan.t5.tokenizer import HuggingfaceTokenizer  # noqa E402
 from lightx2v.utils.envs import *  # noqa E402
 from lightx2v.utils.registry_factory import (  # noqa E402
@@ -34,6 +34,7 @@ from lightx2v.utils.registry_factory import ( # noqa E402
     RMS_WEIGHT_REGISTER,  # noqa E402
 )
 from lightx2v.utils.utils import load_weights  # noqa E402
+from lightx2v_platform.base.global_var import AI_DEVICE  # noqa E402

 __all__ = [
     "T5Model",
@@ -745,7 +746,6 @@ class T5EncoderModel:
         text_len,
         dtype=torch.bfloat16,
         device=torch.device("cuda"),
-        run_device=torch.device("cuda"),
         checkpoint_path=None,
         tokenizer_path=None,
         shard_fn=None,
@@ -758,7 +758,6 @@ class T5EncoderModel:
         self.text_len = text_len
         self.dtype = dtype
         self.device = device
-        self.run_device = run_device
         if t5_quantized_ckpt is not None and t5_quantized:
             self.checkpoint_path = t5_quantized_ckpt
         else:
@@ -807,8 +806,8 @@ class T5EncoderModel:
     def infer(self, texts):
         ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
-        ids = ids.to(self.run_device)
-        mask = mask.to(self.run_device)
+        ids = ids.to(AI_DEVICE)
+        mask = mask.to(AI_DEVICE)
         seq_lens = mask.gt(0).sum(dim=1).long()
         with torch.no_grad():
lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py

@@ -10,8 +10,10 @@ from loguru import logger
 # from lightx2v.attentions import attention
 from lightx2v.common.ops.attn import TorchSDPAWeight
-from lightx2v.models.input_encoders.hf.q_linear import MluQuantLinearInt8, Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
+from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
 from lightx2v.utils.utils import load_weights
+from lightx2v_platform.base.global_var import AI_DEVICE
+from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8

 __all__ = [
     "XLMRobertaCLIP",
@@ -426,9 +428,8 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False, run_device=torch.device("cuda")):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False):
         self.dtype = dtype
-        self.run_device = run_device
         self.quantized = clip_quantized
         self.cpu_offload = cpu_offload
         self.use_31_block = use_31_block
@@ -462,7 +463,7 @@ class CLIPModel:
         return out

     def to_cuda(self):
-        self.model = self.model.to(self.run_device)
+        self.model = self.model.to(AI_DEVICE)

     def to_cpu(self):
         self.model = self.model.cpu()
lightx2v/models/networks/hunyuan_video/infer/pre_infer.py

@@ -5,6 +5,7 @@ import torch
 from einops import rearrange

 from lightx2v.utils.envs import *
+from lightx2v_platform.base.global_var import AI_DEVICE

 from .attn_no_pad import flash_attn_no_pad, flash_attn_no_pad_v3, sage_attn_no_pad_v2
 from .module_io import HunyuanVideo15InferModuleOutput
@@ -68,7 +69,6 @@ class HunyuanVideo15PreInfer:
         self.heads_num = config["heads_num"]
         self.frequency_embedding_size = 256
         self.max_period = 10000
-        self.run_device = torch.device(self.config.get("run_device", "cuda"))

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
@@ -155,7 +155,7 @@ class HunyuanVideo15PreInfer:
             byt5_txt = byt5_txt + weights.cond_type_embedding.apply(torch.ones_like(byt5_txt[:, :, 0], device=byt5_txt.device, dtype=torch.long))
             txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=True)
-        siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=self.run_device))
+        siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=AI_DEVICE))
         txt, text_mask = self.reorder_txt_token(siglip_output, txt, siglip_mask, text_mask)
         txt = txt[:, : text_mask.sum(), :]
lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py

@@ -10,6 +10,7 @@ except Exception as e:
     apply_rope_with_cos_sin_cache_inplace = None

 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE

 from .module_io import HunyuanVideo15ImgBranchOutput, HunyuanVideo15TxtBranchOutput
 from .triton_ops import fuse_scale_shift_kernel
@@ -100,7 +101,6 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
         self.config = config
         self.double_blocks_num = config["mm_double_blocks_depth"]
         self.heads_num = config["heads_num"]
-        self.run_device = torch.device(self.config.get("run_device", "cuda"))
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
             self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)
@@ -222,7 +222,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
         key = torch.cat([img_k, txt_k], dim=1)
         value = torch.cat([img_v, txt_v], dim=1)
         seqlen = query.shape[1]
-        cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(self.run_device, non_blocking=True)
+        cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True)
         if self.config["seq_parallel"]:
             attn_out = weights.self_attention_parallel.apply(
lightx2v/models/networks/hunyuan_video/model.py

@@ -176,12 +176,12 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if self.device.type != "cpu" and dist.is_initialized():
-            device = torch.device("{}:{}".format(self.device.type, dist.get_rank()))
+        if self.config["parallel"]:
+            device = dist.get_rank()
         else:
-            device = self.device
-        with safe_open(file_path, framework="pt", device=str(device)) as f:
+            device = str(self.device)
+        with safe_open(file_path, framework="pt", device=device) as f:
             return {
                 key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
                 for key in f.keys()
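The loader change above keys the safetensors target off config["parallel"] instead of probing the device type: under parallel execution safe_open receives the rank index, otherwise the stringified model device. Restated as a small helper for clarity (this helper does not exist in the repository; it only mirrors the branch shown in the diff):

import torch.distributed as dist

def resolve_load_device(config, device):
    # Mirrors the new branch in _load_safetensor_to_dict.
    if config["parallel"]:
        return dist.get_rank()   # per-rank device index for parallel loading
    return str(device)           # e.g. "cpu" or "cuda"

# with safe_open(file_path, framework="pt", device=resolve_load_device(self.config, self.device)) as f: ...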
lightx2v/models/networks/qwen_image/infer/transformer_infer.py

@@ -111,7 +111,7 @@ def apply_attn(block_weight, hidden_states, encoder_hidden_states, image_rotary_
     if attn_type == "torch_sdpa":
         joint_hidden_states = block_weight.attn.calculate.apply(q=joint_query, k=joint_key, v=joint_value)
-    elif attn_type in ["flash_attn3", "sage_attn2", "mlu_flash_attn", "flash_attn2", "mlu_sage_attn"]:
+    else:
         joint_query = joint_query.squeeze(0)
         joint_key = joint_key.squeeze(0)
         joint_value = joint_value.squeeze(0)
lightx2v/models/networks/qwen_image/model.py

@@ -8,6 +8,7 @@ from safetensors import safe_open
 from lightx2v.utils.envs import *
 from lightx2v.utils.utils import *
+from lightx2v_platform.base.global_var import AI_DEVICE

 from .infer.offload.transformer_infer import QwenImageOffloadTransformerInfer
 from .infer.post_infer import QwenImagePostInfer
@@ -28,7 +29,7 @@ class QwenImageTransformerModel:
         self.model_path = os.path.join(config["model_path"], "transformer")
         self.cpu_offload = config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
-        self.device = torch.device("cpu") if self.cpu_offload else torch.device(self.config.get("run_device", "cuda"))
+        self.device = torch.device("cpu") if self.cpu_offload else torch.device(AI_DEVICE)
         with open(os.path.join(config["model_path"], "transformer", "config.json"), "r") as f:
             transformer_config = json.load(f)
@@ -124,12 +125,12 @@ class QwenImageTransformerModel:
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if self.device.type in ["cuda", "mlu", "npu"] and dist.is_initialized():
-            device = torch.device("{}:{}".format(self.device.type, dist.get_rank()))
+        if self.config["parallel"]:
+            device = dist.get_rank()
         else:
-            device = self.device
-        with safe_open(file_path, framework="pt", device=str(device)) as f:
+            device = str(self.device)
+        with safe_open(file_path, framework="pt", device=device) as f:
             return {
                 key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
                 for key in f.keys()
lightx2v/models/networks/wan/audio_model.py

@@ -13,6 +13,7 @@ from lightx2v.models.networks.wan.weights.audio.transformer_weights import WanAu
 from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
 from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
 from lightx2v.utils.utils import load_weights
+from lightx2v_platform.base.global_var import AI_DEVICE


 class WanAudioModel(WanModel):
@@ -22,7 +23,6 @@ class WanAudioModel(WanModel):
     def __init__(self, model_path, config, device):
         self.config = config
-        self.run_device = self.config.get("run_device", "cuda")
         self._load_adapter_ckpt()
         super().__init__(model_path, config, device)
@@ -51,7 +51,7 @@ class WanAudioModel(WanModel):
         if not adapter_offload:
             if not dist.is_initialized() or not load_from_rank0:
                 for key in self.adapter_weights_dict:
-                    self.adapter_weights_dict[key] = self.adapter_weights_dict[key].to(torch.device(self.run_device))
+                    self.adapter_weights_dict[key] = self.adapter_weights_dict[key].to(torch.device(AI_DEVICE))

     def _init_infer_class(self):
         super()._init_infer_class()
lightx2v/models/networks/wan/infer/pre_infer.py

@@ -10,10 +10,8 @@ class WanPreInfer:
     def __init__(self, config):
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
         self.config = config
-        self.run_device = self.config.get("run_device", "cuda")
         self.clean_cuda_cache = config.get("clean_cuda_cache", False)
         self.task = config["task"]
-        self.device = torch.device(self.config.get("run_device", "cuda"))
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
         self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
lightx2v/models/networks/wan/infer/triton_ops.py

@@ -124,7 +124,7 @@ def fuse_scale_shift_kernel(
     block_l: int = 128,
     block_c: int = 128,
 ):
-    assert x.is_cuda and scale.is_cuda
+    # assert x.is_cuda and scale.is_cuda
     assert x.is_contiguous()
     if x.dim() == 2:
         x = x.unsqueeze(0)
lightx2v/models/networks/wan/model.py

@@ -44,7 +44,6 @@ class WanModel(CompiledMethodsMixin):
         super().__init__()
         self.model_path = model_path
         self.config = config
-        self.device = self.config.get("run_device", "cuda")
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
         self.model_type = model_type
@@ -147,12 +146,12 @@ class WanModel(CompiledMethodsMixin):
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if (self.device.type == "cuda" or self.device.type == "mlu") and dist.is_initialized():
-            device = torch.device("{}:{}".format(self.device.type, dist.get_rank()))
+        if self.config["parallel"]:
+            device = dist.get_rank()
         else:
-            device = self.device
-        with safe_open(file_path, framework="pt", device=str(device)) as f:
+            device = str(self.device)
+        with safe_open(file_path, framework="pt", device=device) as f:
             return {
                 key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
                 for key in f.keys()
lightx2v/models/runners/base_runner.py

@@ -3,6 +3,8 @@ from abc import ABC
 import torch
 import torch.distributed as dist

+from lightx2v_platform.base.global_var import AI_DEVICE
+

 class BaseRunner(ABC):
     """Abstract base class for all Runners
@@ -145,9 +147,9 @@ class BaseRunner(ABC):
         if world_size > 1:
             if rank == signal_rank:
-                t = torch.tensor([stopped], dtype=torch.int32).to(device=self.config.get("run_device", "cuda"))
+                t = torch.tensor([stopped], dtype=torch.int32).to(device=AI_DEVICE)
             else:
-                t = torch.zeros(1, dtype=torch.int32, device=self.config.get("run_device", "cuda"))
+                t = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
             dist.broadcast(t, src=signal_rank)
             stopped = t.item()
lightx2v/models/runners/default_runner.py

@@ -15,6 +15,7 @@ from lightx2v.utils.global_paras import CALIB
 from lightx2v.utils.memory_profiler import peak_memory_decorator
 from lightx2v.utils.profiler import *
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
+from lightx2v_platform.base.global_var import AI_DEVICE

 from .base_runner import BaseRunner
@@ -59,11 +60,10 @@ class DefaultRunner(BaseRunner):
             self.model.compile(self.config.get("compile_shapes", []))

     def set_init_device(self):
-        self.run_device = self.config.get("run_device", "cuda")
         if self.config["cpu_offload"]:
             self.init_device = torch.device("cpu")
         else:
-            self.init_device = torch.device(self.config.get("run_device", "cuda"))
+            self.init_device = torch.device(AI_DEVICE)

     def load_vfi_model(self):
         if self.config["video_frame_interpolation"].get("algo", None) == "rife":
lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py

@@ -21,6 +21,7 @@ from lightx2v.server.metrics import monitor_cli
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import *
+from lightx2v_platform.base.global_var import AI_DEVICE


 @RUNNER_REGISTER("hunyuan_video_1.5")
@@ -71,7 +72,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if qwen25vl_offload:
             qwen25vl_device = torch.device("cpu")
         else:
-            qwen25vl_device = torch.device(self.run_device)
+            qwen25vl_device = torch.device(AI_DEVICE)
         qwen25vl_quantized = self.config.get("qwen25vl_quantized", False)
         qwen25vl_quant_scheme = self.config.get("qwen25vl_quant_scheme", None)
@@ -82,7 +83,6 @@ class HunyuanVideo15Runner(DefaultRunner):
         text_encoder = Qwen25VL_TextEncoder(
             dtype=torch.float16,
             device=qwen25vl_device,
-            run_device=self.run_device,
             checkpoint_path=text_encoder_path,
             cpu_offload=qwen25vl_offload,
             qwen25vl_quantized=qwen25vl_quantized,
@@ -94,9 +94,9 @@ class HunyuanVideo15Runner(DefaultRunner):
         if byt5_offload:
             byt5_device = torch.device("cpu")
         else:
-            byt5_device = torch.device(self.run_device)
+            byt5_device = torch.device(AI_DEVICE)
-        byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, run_device=self.run_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
+        byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
         text_encoders = [text_encoder, byt5]
         return text_encoders
@@ -230,11 +230,10 @@ class HunyuanVideo15Runner(DefaultRunner):
         if siglip_offload:
             siglip_device = torch.device("cpu")
         else:
-            siglip_device = torch.device(self.run_device)
+            siglip_device = torch.device(AI_DEVICE)
         image_encoder = SiglipVisionEncoder(
             config=self.config,
             device=siglip_device,
-            run_device=self.run_device,
             checkpoint_path=self.config["model_path"],
             cpu_offload=siglip_offload,
         )
@@ -246,7 +245,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.run_device)
+            vae_device = torch.device(AI_DEVICE)
         vae_config = {
             "checkpoint_path": self.config["model_path"],
@@ -265,7 +264,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.run_device)
+            vae_device = torch.device(AI_DEVICE)
         vae_config = {
             "checkpoint_path": self.config["model_path"],
@@ -275,7 +274,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         }
         if self.config.get("use_tae", False):
             tae_path = self.config["tae_path"]
-            vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(self.run_device)
+            vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(AI_DEVICE)
         else:
             vae_decoder = self.vae_cls(**vae_config)
         return vae_decoder
@@ -350,7 +349,7 @@ class HunyuanVideo15Runner(DefaultRunner):
             self.model_sr.scheduler.step_post()
             del self.inputs_sr
-            torch_ext_module = getattr(torch, self.run_device)
+            torch_ext_module = getattr(torch, AI_DEVICE)
             torch_ext_module.empty_cache()
             self.config_sr["is_sr_running"] = False
@@ -369,10 +368,10 @@ class HunyuanVideo15Runner(DefaultRunner):
         text_encoder_output = self.run_text_encoder(self.input_info)
         # vision_states is all zero, because we don't have any image input
-        siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(self.run_device)
-        siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(self.run_device))
+        siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(AI_DEVICE)
+        siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(AI_DEVICE))
-        torch_ext_module = getattr(torch, self.run_device)
+        torch_ext_module = getattr(torch, AI_DEVICE)
         torch_ext_module.empty_cache()
         gc.collect()
         return {
@@ -400,7 +399,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         siglip_output, siglip_mask = self.run_image_encoder(img_ori) if self.config.get("use_image_encoder", True) else None
         cond_latents = self.run_vae_encoder(img_ori)
         text_encoder_output = self.run_text_encoder(self.input_info)
-        torch_ext_module = getattr(torch, self.run_device)
+        torch_ext_module = getattr(torch, AI_DEVICE)
         torch_ext_module.empty_cache()
         gc.collect()
         return {
@@ -427,9 +426,9 @@ class HunyuanVideo15Runner(DefaultRunner):
         target_height = self.target_height
         input_image_np = self.resize_and_center_crop(first_frame, target_width=target_width, target_height=target_height)
-        vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(self.run_device), dtype=torch.bfloat16)
+        vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(AI_DEVICE), dtype=torch.bfloat16)
         image_encoder_output = self.image_encoder.infer(vision_states)
-        image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(self.run_device))
+        image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(AI_DEVICE))
         return image_encoder_output, image_encoder_mask

     def resize_and_center_crop(self, image, target_width, target_height):
@@ -480,6 +479,6 @@ class HunyuanVideo15Runner(DefaultRunner):
             ]
         )
-        ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(self.run_device)
+        ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(AI_DEVICE)
         cond_latents = self.vae_encoder.encode(ref_images_pixel_values.to(GET_DTYPE()))
         return cond_latents
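Most hunks in this runner repeat one selection rule: a sub-module sits on the CPU when its offload flag is set and on the platform device otherwise. Sketched as a helper for illustration (this function is not part of the repository):

import torch
from lightx2v_platform.base.global_var import AI_DEVICE

def select_device(offload: bool) -> torch.device:
    # Offloaded encoders/VAEs stay on CPU; everything else goes to the platform device.
    return torch.device("cpu") if offload else torch.device(AI_DEVICE)

# e.g. qwen25vl_device = select_device(qwen25vl_offload)
#      byt5_device     = select_device(byt5_offload)
#      vae_device      = select_device(vae_offload)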
lightx2v/models/runners/qwen_image/qwen_image_runner.py

@@ -15,6 +15,9 @@ from lightx2v.server.metrics import monitor_cli
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)


 def calculate_dimensions(target_area, ratio):
@@ -85,9 +88,7 @@ class QwenImageRunner(DefaultRunner):
     def _run_input_encoder_local_t2i(self):
         prompt = self.input_info.prompt
         text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
-        if hasattr(torch, self.run_device):
-            torch_module = getattr(torch, self.run_device)
-            torch_module.empty_cache()
+        torch_device_module.empty_cache()
         gc.collect()
         return {
             "text_encoder_output": text_encoder_output,
@@ -102,7 +103,7 @@ class QwenImageRunner(DefaultRunner):
         if GET_RECORDER_MODE():
             width, height = img_ori.size
             monitor_cli.lightx2v_input_image_len.observe(width * height)
-        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.run_device)
+        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
         self.input_info.original_size.append(img_ori.size)
         return img, img_ori
@@ -121,9 +122,7 @@ class QwenImageRunner(DefaultRunner):
         for vae_image in text_encoder_output["image_info"]["vae_image_list"]:
             image_encoder_output = self.run_vae_encoder(image=vae_image)
             image_encoder_output_list.append(image_encoder_output)
-        if hasattr(torch, self.run_device):
-            torch_module = getattr(torch, self.run_device)
-            torch_module.empty_cache()
+        torch_device_module.empty_cache()
         gc.collect()
         return {
             "text_encoder_output": text_encoder_output,
@@ -238,9 +237,7 @@ class QwenImageRunner(DefaultRunner):
         images = self.vae.decode(latents, self.input_info)
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_decoder
-        if hasattr(torch, self.run_device):
-            torch_module = getattr(torch, self.run_device)
-            torch_module.empty_cache()
+        torch_device_module.empty_cache()
         gc.collect()
         return images
@@ -259,9 +256,7 @@ class QwenImageRunner(DefaultRunner):
             image.save(f"{input_info.save_result_path}")
         del latents, generator
-        if hasattr(torch, self.run_device):
-            torch_module = getattr(torch, self.run_device)
-            torch_module.empty_cache()
+        torch_device_module.empty_cache()
         gc.collect()
         # Return (images, audio) - audio is None for default runner
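This runner resolves the backend module once at import time, torch_device_module = getattr(torch, AI_DEVICE), which replaces the repeated hasattr/getattr guard around each cache flush. A sketch of the resulting cleanup pattern, assuming the named backend namespace (torch.cuda, torch.mlu, torch.npu, ...) provides empty_cache():

import gc
import torch
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)  # resolved once per process

def free_device_memory():
    # Mirrors the flush used after text encoding, VAE encoding, and decoding in this runner.
    torch_device_module.empty_cache()
    gc.collect()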
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -33,6 +33,7 @@ from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import find_torch_model_path, load_weights, vae_to_comfyui_image_inplace
+from lightx2v_platform.base.global_var import AI_DEVICE

 warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")
 warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io")
@@ -450,7 +451,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             ref_img = img_path
         else:
             ref_img = load_image(img_path)
-        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(self.run_device)
+        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
         ref_img, h, w = resize_image(
             ref_img,
@@ -538,15 +539,14 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
         """Prepare previous latents for conditioning"""
-        device = self.run_device
         dtype = GET_DTYPE()
         tgt_h, tgt_w = self.input_info.target_shape[0], self.input_info.target_shape[1]
-        prev_frames = torch.zeros((1, 3, self.config["target_video_length"], tgt_h, tgt_w), device=device)
+        prev_frames = torch.zeros((1, 3, self.config["target_video_length"], tgt_h, tgt_w), device=AI_DEVICE)
         if prev_video is not None:
             # Extract and process last frames
-            last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
+            last_frames = prev_video[:, :, -prev_frame_length:].clone().to(AI_DEVICE)
             if self.config["model_cls"] != "wan2.2_audio":
                 last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
             prev_frames[:, :, :prev_frame_length] = last_frames
@@ -574,7 +574,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
         frames_n = (nframe - 1) * 4 + 1
-        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+        prev_mask = torch.ones((1, frames_n, height, width), device=AI_DEVICE, dtype=dtype)
         prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
         prev_mask[:, prev_frame_len:] = 0
         prev_mask = self._wan_mask_rearrange(prev_mask)
@@ -835,7 +835,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_audio_encoder(self):
         audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"))
         audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
-        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, run_device=self.config.get("run_device", "cuda"))
+        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload)
         return model

     def load_audio_adapter(self):
@@ -843,7 +843,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if audio_adapter_offload:
             device = torch.device("cpu")
         else:
-            device = torch.device(self.run_device)
+            device = torch.device(AI_DEVICE)
         audio_adapter = AudioAdapter(
             attention_head_dim=self.config["dim"] // self.config["num_heads"],
             num_attention_heads=self.config["num_heads"],
@@ -856,7 +856,6 @@ class WanAudioRunner(WanRunner):  # type:ignore
             quantized=self.config.get("adapter_quantized", False),
             quant_scheme=self.config.get("adapter_quant_scheme", None),
             cpu_offload=audio_adapter_offload,
-            run_device=self.run_device,
         )
         audio_adapter.to(device)
@@ -892,11 +891,10 @@ class Wan22AudioRunner(WanAudioRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.run_device)
+            vae_device = torch.device(AI_DEVICE)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
-            "run_device": self.run_device,
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
         }
@@ -909,11 +907,10 @@ class Wan22AudioRunner(WanAudioRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.run_device)
+            vae_device = torch.device(AI_DEVICE)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
-            "run_device": self.run_device,
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
         }
lightx2v/models/runners/wan/wan_runner.py

@@ -29,6 +29,7 @@ from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import *
+from lightx2v_platform.base.global_var import AI_DEVICE


 @RUNNER_REGISTER("wan2.1")
@@ -65,7 +66,7 @@ class WanRunner(DefaultRunner):
         if clip_offload:
             clip_device = torch.device("cpu")
         else:
-            clip_device = torch.device(self.run_device)
+            clip_device = torch.device(AI_DEVICE)
         # quant_config
         clip_quantized = self.config.get("clip_quantized", False)
         if clip_quantized:
@@ -84,7 +85,6 @@ class WanRunner(DefaultRunner):
         image_encoder = CLIPModel(
             dtype=torch.float16,
             device=clip_device,
-            run_device=self.run_device,
             checkpoint_path=clip_original_ckpt,
             clip_quantized=clip_quantized,
             clip_quantized_ckpt=clip_quantized_ckpt,
@@ -102,7 +102,7 @@ class WanRunner(DefaultRunner):
         if t5_offload:
             t5_device = torch.device("cpu")
         else:
-            t5_device = torch.device(self.run_device)
+            t5_device = torch.device(AI_DEVICE)
         tokenizer_path = os.path.join(self.config["model_path"], "google/umt5-xxl")
         # quant_config
         t5_quantized = self.config.get("t5_quantized", False)
@@ -123,7 +123,6 @@ class WanRunner(DefaultRunner):
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
             device=t5_device,
-            run_device=self.run_device,
             checkpoint_path=t5_original_ckpt,
             tokenizer_path=tokenizer_path,
             shard_fn=None,
@@ -142,12 +141,11 @@ class WanRunner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.run_device)
+            vae_device = torch.device(AI_DEVICE)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
             "device": vae_device,
-            "run_device": self.run_device,
             "parallel": self.config["parallel"],
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
@@ -171,7 +169,6 @@ class WanRunner(DefaultRunner):
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
             "device": vae_device,
-            "run_device": self.run_device,
             "parallel": self.config["parallel"],
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
@@ -321,7 +318,7 @@ class WanRunner(DefaultRunner):
             self.config["target_video_length"],
             lat_h,
             lat_w,
-            device=torch.device(self.run_device),
+            device=torch.device(AI_DEVICE),
         )
         if last_frame is not None:
             msk[:, 1:-1] = 0
@@ -343,7 +340,7 @@ class WanRunner(DefaultRunner):
                     torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                 ],
                 dim=1,
-            ).to(self.run_device)
+            ).to(AI_DEVICE)
         else:
             vae_input = torch.concat(
                 [
@@ -351,7 +348,7 @@ class WanRunner(DefaultRunner):
                     torch.zeros(3, self.config["target_video_length"] - 1, h, w),
                 ],
                 dim=1,
-            ).to(self.run_device)
+            ).to(AI_DEVICE)
         vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()))
@@ -534,7 +531,7 @@ class Wan22DenseRunner(WanRunner):
         assert img.width == ow and img.height == oh
         # to tensor
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.run_device).unsqueeze(1)
+        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(AI_DEVICE).unsqueeze(1)
         vae_encoder_out = self.get_vae_encoder_output(img)
         latent_w, latent_h = ow // self.config["vae_stride"][2], oh // self.config["vae_stride"][1]
         latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
lightx2v/models/schedulers/hunyuan_video/posemb_layers.py

@@ -271,8 +271,7 @@ def get_1d_rotary_pos_embed(
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
     # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
-    run_device = kwds.get("run_device", "cuda")
-    freqs = torch.outer(pos * interpolation_factor, freqs).to(run_device)  # [S, D/2]
+    freqs = torch.outer(pos * interpolation_factor, freqs).to(AI_DEVICE)  # [S, D/2]
     if use_real:
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
lightx2v/models/schedulers/hunyuan_video/scheduler.py

@@ -11,7 +11,6 @@ from .posemb_layers import get_nd_rotary_pos_embed
 class HunyuanVideo15Scheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.run_device = torch.device(self.config.get("run_device", "cuda"))
         self.reverse = True
         self.num_train_timesteps = 1000
         self.sample_shift = self.config["sample_shift"]
@@ -25,13 +24,13 @@ class HunyuanVideo15Scheduler(BaseScheduler):
     def prepare(self, seed, latent_shape, image_encoder_output=None):
         self.prepare_latents(seed, latent_shape, dtype=torch.bfloat16)
-        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
+        self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift)
         self.multitask_mask = self.get_task_mask(self.config["task"], latent_shape[-3])
         self.cond_latents_concat, self.mask_concat = self._prepare_cond_latents_and_mask(self.config["task"], image_encoder_output["cond_latents"], self.latents, self.multitask_mask, self.reorg_token)
         self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))

     def prepare_latents(self, seed, latent_shape, dtype=torch.bfloat16):
-        self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
+        self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed)
         self.latents = torch.randn(
             1,
             latent_shape[0],
@@ -39,7 +38,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
             latent_shape[2],
             latent_shape[3],
             dtype=dtype,
-            device=self.run_device,
+            device=AI_DEVICE,
             generator=self.generator,
         )
@@ -127,7 +126,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
         if rope_dim_list is None:
             rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
         assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
-        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=self.run_device)
+        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=AI_DEVICE)
         cos_half = freqs_cos[:, ::2].contiguous()
         sin_half = freqs_sin[:, ::2].contiguous()
         cos_sin = torch.cat([cos_half, sin_half], dim=-1)
@@ -149,9 +148,8 @@ class HunyuanVideo15SRScheduler(HunyuanVideo15Scheduler):
     def prepare(self, seed, latent_shape, lq_latents, upsampler, image_encoder_output=None):
         dtype = lq_latents.dtype
-        device = lq_latents.device
         self.prepare_latents(seed, latent_shape, lq_latents, dtype=dtype)
-        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
+        self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift)
         self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))
         tgt_shape = latent_shape[-2:]