xuwx1 / LightX2V · Commit 4c0a9a0d (unverified)

Fix device bugs (#527)

Authored Nov 27, 2025 by Gu Shiqiao; committed via GitHub on Nov 27, 2025. Parent: fbb19ffc.

Changes: 29. Showing 20 changed files with 142 additions and 77 deletions (+142 −77).
Changed files shown on this page:

configs/seko_talk/mlu/seko_talk_bf16.json  (+19 −0)
lightx2v/common/ops/mm/mm_weight.py  (+19 −0)
lightx2v/infer.py  (+5 −1)
lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py  (+6 −5)
lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py  (+6 −5)
lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py  (+7 −2)
lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py  (+4 −4)
lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py  (+6 −5)
lightx2v/models/input_encoders/hf/wan/t5/model.py  (+5 −3)
lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py  (+0 −1)
lightx2v/models/networks/hunyuan_video/infer/pre_infer.py  (+2 −2)
lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py  (+2 −2)
lightx2v/models/networks/wan/infer/transformer_infer.py  (+10 −2)
lightx2v/models/networks/wan/infer/utils.py  (+5 −1)
lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py  (+17 −16)
lightx2v/models/runners/qwen_image/qwen_image_runner.py  (+9 −9)
lightx2v/models/runners/wan/wan_audio_runner.py  (+5 −5)
lightx2v/models/runners/wan/wan_runner.py  (+7 −6)
lightx2v/models/schedulers/hunyuan_video/posemb_layers.py  (+2 −2)
lightx2v/models/schedulers/hunyuan_video/scheduler.py  (+6 −6)
configs/seko_talk/mlu/seko_talk_bf16.json  (new file, mode 100644, +19 −0)

{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 360,
    "audio_sr": 16000,
    "target_video_length": 81,
    "resize_mode": "adaptive",
    "self_attn_1_type": "flash_attn2",
    "cross_attn_1_type": "flash_attn2",
    "cross_attn_2_type": "flash_attn2",
    "sample_guide_scale": 1.0,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "use_31_block": false,
    "run_device": "mlu",
    "rope_type": "torch",
    "modulate_type": "torch"
}
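The new config selects the Cambricon MLU backend through the "run_device" key rather than a hard-coded "cuda". A minimal sketch of how such a string drives tensor placement, mirroring the repo's own torch.device(config.get("run_device", "cuda")) pattern; the config dict and fallback here are illustrative, and "mlu" only resolves when the Cambricon PyTorch extension is installed, so this example falls back to CPU for portability:

import torch

config = {"run_device": "mlu", "cpu_offload": False}

try:
    # Same construction the runners use; raises if the backend name is unknown.
    run_device = torch.device(config.get("run_device", "cuda"))
except RuntimeError:
    run_device = torch.device("cpu")

# Placement then works identically for "cuda", "mlu", or any other registered backend.
x = torch.randn(2, 3).to(run_device)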
lightx2v/common/ops/mm/mm_weight.py  (+19 −0)

@@ -442,6 +442,25 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             else:
                 raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+        if self.bias_name is not None:
+            if self.create_cuda_buffer:
+                # move to cuda buffer
+                self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
+            else:
+                device = weight_dict[self.bias_name].device
+                if device.type == "cuda":
+                    self.bias = weight_dict[self.bias_name]
+                elif device.type == "cpu":
+                    bias_shape = weight_dict[self.bias_name].shape
+                    bias_dtype = weight_dict[self.bias_name].dtype
+                    self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                    self.pin_bias.copy_(weight_dict[self.bias_name])
+                else:
+                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+        else:
+            self.bias = None
+            self.pin_bias = None

     def load_fp8_perblock128_sym(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
             self.weight = weight_dict[self.weight_name]
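The CPU branch above stages the bias in page-locked (pinned) host memory so the later host-to-device copy can run asynchronously. A standalone sketch of that pattern; the helper name is illustrative, not part of LightX2V:

import torch

def stage_bias_on_host(bias: torch.Tensor) -> torch.Tensor:
    # Stage a CPU tensor in pinned memory so a later .to(device, non_blocking=True)
    # can overlap with compute. Pinning needs an accelerator-enabled build,
    # so fall back to a plain clone elsewhere.
    if torch.cuda.is_available():
        pinned = torch.empty(bias.shape, dtype=bias.dtype, pin_memory=True)
        pinned.copy_(bias)
        return pinned
    return bias.clone()

pin_bias = stage_bias_on_host(torch.randn(4096))
if torch.cuda.is_available():
    bias_gpu = pin_bias.to("cuda", non_blocking=True)  # asynchronous thanks to the pinned source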
lightx2v/infer.py  (+5 −1)

@@ -3,7 +3,11 @@ import argparse
 import torch
 import torch.distributed as dist
 from loguru import logger
-from torch.distributed import ProcessGroupNCCL
+
+try:
+    from torch.distributed import ProcessGroupNCCL
+except ImportError:
+    ProcessGroupNCCL = None

 from lightx2v.common.ops import *
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner  # noqa: F401
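The NCCL process-group binding only exists in CUDA builds of PyTorch, so importing it unconditionally breaks non-NVIDIA backends such as the MLU config added by this commit. A small sketch of the guarded-import pattern; the helper below is illustrative and not part of LightX2V:

import torch

try:
    from torch.distributed import ProcessGroupNCCL
except ImportError:
    ProcessGroupNCCL = None

def pick_dist_backend() -> str:
    # Use NCCL only when its binding and a CUDA device are actually available,
    # otherwise fall back to gloo.
    return "nccl" if ProcessGroupNCCL is not None and torch.cuda.is_available() else "gloo"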
lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py  (+6 −5)

@@ -159,13 +159,14 @@ class ByT5TextEncoder:
         self,
         config,
         device=torch.device("cpu"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         byt5_max_length=256,
         cpu_offload=False,
     ):
         self.cpu_offload = cpu_offload
         self.config = config
-        self.device = device
+        self.run_device = run_device
         self.byt5_max_length = byt5_max_length
         self.enable_cfg = config.get("enable_cfg", False)
         byT5_google_path = os.path.join(checkpoint_path, "text_encoder", "byt5-small")

@@ -300,12 +301,12 @@ class ByT5TextEncoder:
         negative_masks = []

         for prompt in prompt_list:
-            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, self.device)
+            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, self.run_device)
             positive_embeddings.append(pos_emb)
             positive_masks.append(pos_mask)

             if self.enable_cfg:
                 # TODO: split the cfg pass out; better suited for parallel execution
-                neg_emb, neg_mask = self._process_single_byt5_prompt("", self.device)
+                neg_emb, neg_mask = self._process_single_byt5_prompt("", self.run_device)
                 negative_embeddings.append(neg_emb)
                 negative_masks.append(neg_mask)

@@ -327,8 +328,8 @@ class ByT5TextEncoder:
     @torch.no_grad()
     def infer(self, prompts):
         if self.cpu_offload:
-            self.byt5_model = self.byt5_model.to(self.device)
-            self.byt5_mapper = self.byt5_mapper.to(self.device)
+            self.byt5_model = self.byt5_model.to(self.run_device)
+            self.byt5_mapper = self.byt5_mapper.to(self.run_device)
         byt5_embeddings, byt5_masks = self._prepare_byt5_embeddings(prompts)
         byt5_features = self.byt5_mapper(byt5_embeddings.to(torch.bfloat16))
         if self.cpu_offload:
lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py  (+6 −5)

@@ -552,6 +552,7 @@ class Qwen25VL_TextEncoder:
         text_len=1000,
         dtype=torch.float16,
         device=torch.device("cpu"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
         qwen25vl_quantized=False,

@@ -560,7 +561,7 @@ class Qwen25VL_TextEncoder:
     ):
         self.text_len = text_len
         self.dtype = dtype
-        self.device = device
+        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.qwen25vl_quantized = qwen25vl_quantized
         self.qwen25vl_quant_scheme = qwen25vl_quant_scheme

@@ -589,20 +590,20 @@ class Qwen25VL_TextEncoder:
     def infer(self, texts):
         if self.cpu_offload:
-            self.text_encoder = self.text_encoder.to(self.device)
+            self.text_encoder = self.text_encoder.to(self.run_device)

         text_inputs = self.text_encoder.text2tokens(texts, data_type="video", max_length=self.text_len)
-        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=self.device)
+        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=self.run_device)
         if self.cpu_offload:
             self.text_encoder = self.text_encoder.to("cpu")
         prompt_embeds = prompt_outputs.hidden_state
         attention_mask = prompt_outputs.attention_mask
         if attention_mask is not None:
-            attention_mask = attention_mask.to(self.device)
+            attention_mask = attention_mask.to(self.run_device)
             _, seq_len = attention_mask.shape
             attention_mask = attention_mask.repeat(1, self.num_videos_per_prompt)
             attention_mask = attention_mask.view(self.num_videos_per_prompt, seq_len)
-        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
+        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.run_device)
         seq_len = prompt_embeds.shape[1]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py  (+7 −2)

@@ -95,6 +95,7 @@ class VisionEncoder(nn.Module):
         output_key: Optional[str] = None,
         logger=None,
         device=None,
+        run_device=None,
         cpu_offload=False,
     ):
         super().__init__()

@@ -120,6 +121,7 @@ class VisionEncoder(nn.Module):
         )
         self.dtype = self.model.dtype
         self.device = self.model.device
+        self.run_device = run_device
         self.processor, self.processor_path = load_image_processor(
             processor_type=self.processor_type,

@@ -175,7 +177,7 @@ class VisionEncoder(nn.Module):
         if isinstance(images, np.ndarray):
             # Preprocess images if they're numpy arrays
-            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=self.device, dtype=self.model.dtype)
+            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=self.run_device, dtype=self.model.dtype)
         else:
             # Assume already preprocessed
             preprocessed = images

@@ -230,11 +232,13 @@ class SiglipVisionEncoder:
         self,
         config,
         device=torch.device("cpu"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
     ):
         self.config = config
         self.device = device
+        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.vision_states_dim = 1152
         vision_encoder_path = os.path.join(checkpoint_path, "vision_encoder", "siglip")

@@ -248,6 +252,7 @@ class SiglipVisionEncoder:
             output_key=None,
             logger=None,
             device=self.device,
+            run_device=self.run_device,
             cpu_offload=self.cpu_offload,
         )

@@ -265,7 +270,7 @@ class SiglipVisionEncoder:
     @torch.no_grad()
     def infer(self, vision_states):
         if self.cpu_offload:
-            self.vision_in = self.vision_in.to("cuda")
+            self.vision_in = self.vision_in.to(self.run_device)
         vision_states = self.vision_in(vision_states)
         if self.cpu_offload:
             self.vision_in = self.vision_in.to("cpu")
lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py  (+4 −4)

@@ -252,7 +252,7 @@ class AudioAdapter(nn.Module):
         quantized: bool = False,
         quant_scheme: str = None,
         cpu_offload: bool = False,
-        device=torch.device("cpu"),
+        run_device=torch.device("cuda"),
     ):
         super().__init__()
         self.cpu_offload = cpu_offload

@@ -263,7 +263,7 @@ class AudioAdapter(nn.Module):
             mlp_dims=mlp_dims,
             transformer_layers=projection_transformer_layers,
         )
-        self.device = torch.device(device)
+        self.run_device = run_device
         # self.num_tokens = num_tokens * 4
         self.num_tokens_x4 = num_tokens * 4
         self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)

@@ -302,10 +302,10 @@ class AudioAdapter(nn.Module):
     @torch.no_grad()
     def forward_audio_proj(self, audio_feat, latent_frame):
         if self.cpu_offload:
-            self.audio_proj.to(self.device)
+            self.audio_proj.to(self.run_device)
         x = self.audio_proj(audio_feat, latent_frame)
         x = self.rearange_audio_features(x)
-        x = x + self.audio_pe.to(self.device)
+        x = x + self.audio_pe.to(self.run_device)
         if self.cpu_offload:
             self.audio_proj.to("cpu")
         return x
lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py  (+6 −5)

@@ -5,14 +5,15 @@ from lightx2v.utils.envs import *

 class SekoAudioEncoderModel:
-    def __init__(self, model_path, audio_sr, cpu_offload, device):
+    def __init__(self, model_path, audio_sr, cpu_offload, run_device):
         self.model_path = model_path
         self.audio_sr = audio_sr
         self.cpu_offload = cpu_offload
         if self.cpu_offload:
             self.device = torch.device("cpu")
         else:
-            self.device = torch.device(device)
+            self.device = torch.device(run_device)
+        self.run_device = run_device
         self.load()

     def load(self):

@@ -26,13 +27,13 @@ class SekoAudioEncoderModel:
             self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")

     def to_cuda(self):
-        self.audio_feature_encoder = self.audio_feature_encoder.to(self.device)
+        self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)

     @torch.no_grad()
     def infer(self, audio_segment):
-        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.device).to(dtype=GET_DTYPE())
+        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.run_device).to(dtype=GET_DTYPE())
         if self.cpu_offload:
-            self.audio_feature_encoder = self.audio_feature_encoder.to(self.device)
+            self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
         audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
         if self.cpu_offload:
             self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
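Several of the encoders above share the same offload round trip: move the module to the compute device, run it, then park it back on CPU. A minimal sketch of that pattern under assumed, illustrative names (not LightX2V's API):

import torch
import torch.nn as nn

def offloaded_forward(module: nn.Module, x: torch.Tensor, run_device, cpu_offload: bool):
    # Move the module to the configured run_device for the forward pass,
    # then return it to CPU so accelerator memory is freed between calls.
    if cpu_offload:
        module.to(run_device)
    out = module(x.to(run_device))
    if cpu_offload:
        module.to("cpu")
    return out

proj = nn.Linear(8, 8)
y = offloaded_forward(proj, torch.randn(2, 8), torch.device("cpu"), cpu_offload=True)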
lightx2v/models/input_encoders/hf/wan/t5/model.py  (+5 −3)

@@ -744,7 +744,8 @@ class T5EncoderModel:
         self,
         text_len,
         dtype=torch.bfloat16,
-        device=torch.device("cpu"),
+        device=torch.device("cuda"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         tokenizer_path=None,
         shard_fn=None,

@@ -757,6 +758,7 @@ class T5EncoderModel:
         self.text_len = text_len
         self.dtype = dtype
         self.device = device
+        self.run_device = run_device
         if t5_quantized_ckpt is not None and t5_quantized:
             self.checkpoint_path = t5_quantized_ckpt
         else:

@@ -805,8 +807,8 @@ class T5EncoderModel:
     def infer(self, texts):
         ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
-        ids = ids.to(self.device)
-        mask = mask.to(self.device)
+        ids = ids.to(self.run_device)
+        mask = mask.to(self.run_device)
         seq_lens = mask.gt(0).sum(dim=1).long()
         with torch.no_grad():
lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py  (+0 −1)

@@ -428,7 +428,6 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
class CLIPModel:
    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False, run_device=torch.device("cuda")):
        self.dtype = dtype
        self.device = device
        self.run_device = run_device
        self.quantized = clip_quantized
        self.cpu_offload = cpu_offload
lightx2v/models/networks/hunyuan_video/infer/pre_infer.py  (+2 −2)

@@ -68,7 +68,7 @@ class HunyuanVideo15PreInfer:
         self.heads_num = config["heads_num"]
         self.frequency_embedding_size = 256
         self.max_period = 10000
-        self.device = torch.device(self.config.get("run_device", "cuda"))
+        self.run_device = torch.device(self.config.get("run_device", "cuda"))

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

@@ -155,7 +155,7 @@ class HunyuanVideo15PreInfer:
             byt5_txt = byt5_txt + weights.cond_type_embedding.apply(torch.ones_like(byt5_txt[:, :, 0], device=byt5_txt.device, dtype=torch.long))
             txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=True)
-        siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=self.device))
+        siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=self.run_device))
         txt, text_mask = self.reorder_txt_token(siglip_output, txt, siglip_mask, text_mask)
         txt = txt[:, : text_mask.sum(), :]
lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py  (+2 −2)

@@ -100,7 +100,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
         self.config = config
         self.double_blocks_num = config["mm_double_blocks_depth"]
         self.heads_num = config["heads_num"]
-        self.device = torch.device(self.config.get("run_device", "cuda"))
+        self.run_device = torch.device(self.config.get("run_device", "cuda"))
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
             self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)

@@ -222,7 +222,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
         key = torch.cat([img_k, txt_k], dim=1)
         value = torch.cat([img_v, txt_v], dim=1)
         seqlen = query.shape[1]
-        cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(self.device, non_blocking=True)
+        cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(self.run_device, non_blocking=True)
         if self.config["seq_parallel"]:
             attn_out = weights.self_attention_parallel.apply(
lightx2v/models/networks/wan/infer/transformer_infer.py  (+10 −2)

@@ -9,6 +9,10 @@ from .triton_ops import fuse_scale_shift_kernel
 from .utils import apply_wan_rope_with_chunk, apply_wan_rope_with_flashinfer, apply_wan_rope_with_torch


+def modulate(x, scale, shift):
+    return x * (1 + scale.squeeze()) + shift.squeeze()
+

 class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config

@@ -21,6 +25,10 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
+        if self.config.get("modulate_type", "triton") == "triton":
+            self.modulate_func = fuse_scale_shift_kernel
+        else:
+            self.modulate_func = modulate
         if self.config.get("rope_type", "flashinfer") == "flashinfer":
             if self.config.get("rope_chunk", False):
                 self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_flashinfer)

@@ -146,7 +154,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         norm1_out = phase.norm1.apply(x)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm1_out = norm1_out.to(self.sensitive_layer_dtype)
-        norm1_out = fuse_scale_shift_kernel(norm1_out, scale=scale_msa, shift=shift_msa).squeeze(0)
+        norm1_out = self.modulate_func(norm1_out, scale=scale_msa, shift=shift_msa).squeeze()
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm1_out = norm1_out.to(self.infer_dtype)

@@ -285,7 +293,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         norm2_out = phase.norm2.apply(x)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm2_out = norm2_out.to(self.sensitive_layer_dtype)
-        norm2_out = fuse_scale_shift_kernel(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze(0)
+        norm2_out = self.modulate_func(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze()
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm2_out = norm2_out.to(self.infer_dtype)
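The new pure-PyTorch modulate() is selected when config["modulate_type"] is not "triton" (the MLU config sets "modulate_type": "torch"), presumably computing the same scale/shift math as the fused Triton kernel but without a CUDA dependency. A small self-contained check of the fallback on illustrative shapes:

import torch

def modulate(x, scale, shift):
    # Same expression as the hunk above: x * (1 + scale) + shift with singleton dims squeezed.
    return x * (1 + scale.squeeze()) + shift.squeeze()

x = torch.randn(16, 128)       # [seq, dim] activation
scale = torch.randn(1, 128)    # per-dim scale
shift = torch.randn(1, 128)    # per-dim shift
out = modulate(x, scale, shift)
assert out.shape == (16, 128)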
lightx2v/models/networks/wan/infer/utils.py  (+5 −1)

 import torch
 import torch.distributed as dist
-from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+
+try:
+    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+except ImportError:
+    apply_rope_with_cos_sin_cache_inplace = None

 from lightx2v.utils.envs import *
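This mirrors the ProcessGroupNCCL guard: flashinfer is an optional CUDA dependency, and on backends without it the config selects the torch rope path ("rope_type": "torch" in the new MLU config). A sketch of how a caller might check availability; the helper name is illustrative, not the repo's:

try:
    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
except ImportError:
    apply_rope_with_cos_sin_cache_inplace = None

def flashinfer_rope_available() -> bool:
    # Callers that see False should take the pure-torch rope implementation instead.
    return apply_rope_with_cos_sin_cache_inplace is not None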
lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py  (+17 −16)

@@ -71,7 +71,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if qwen25vl_offload:
             qwen25vl_device = torch.device("cpu")
         else:
-            qwen25vl_device = torch.device(self.config.get("run_device", "cuda"))
+            qwen25vl_device = torch.device(self.run_device)
         qwen25vl_quantized = self.config.get("qwen25vl_quantized", False)
         qwen25vl_quant_scheme = self.config.get("qwen25vl_quant_scheme", None)

@@ -82,6 +82,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         text_encoder = Qwen25VL_TextEncoder(
             dtype=torch.float16,
             device=qwen25vl_device,
+            run_device=self.run_device,
             checkpoint_path=text_encoder_path,
             cpu_offload=qwen25vl_offload,
             qwen25vl_quantized=qwen25vl_quantized,

@@ -93,9 +94,9 @@ class HunyuanVideo15Runner(DefaultRunner):
         if byt5_offload:
             byt5_device = torch.device("cpu")
         else:
-            byt5_device = torch.device(self.config.get("run_device", "cuda"))
+            byt5_device = torch.device(self.run_device)
-        byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
+        byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, run_device=self.run_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
         text_encoders = [text_encoder, byt5]
         return text_encoders

@@ -229,10 +230,11 @@ class HunyuanVideo15Runner(DefaultRunner):
         if siglip_offload:
             siglip_device = torch.device("cpu")
         else:
-            siglip_device = torch.device(self.config.get("run_device", "cuda"))
+            siglip_device = torch.device(self.run_device)
         image_encoder = SiglipVisionEncoder(
             config=self.config,
             device=siglip_device,
+            run_device=self.run_device,
             checkpoint_path=self.config["model_path"],
             cpu_offload=siglip_offload,
         )

@@ -244,7 +246,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.config.get("run_device", "cuda"))
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "checkpoint_path": self.config["model_path"],

@@ -263,7 +265,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.config.get("run_device", "cuda"))
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "checkpoint_path": self.config["model_path"],

@@ -273,7 +275,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         }
         if self.config.get("use_tae", False):
             tae_path = self.config["tae_path"]
-            vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(self.config.get("run_device", "cuda"))
+            vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(self.run_device)
         else:
             vae_decoder = self.vae_cls(**vae_config)
         return vae_decoder

@@ -348,7 +350,7 @@ class HunyuanVideo15Runner(DefaultRunner):
             self.model_sr.scheduler.step_post()
         del self.inputs_sr
-        torch_ext_module = getattr(torch, self.config.get("run_device", "cuda"))
+        torch_ext_module = getattr(torch, self.run_device)
         torch_ext_module.empty_cache()
         self.config_sr["is_sr_running"] = False

@@ -367,10 +369,10 @@ class HunyuanVideo15Runner(DefaultRunner):
         text_encoder_output = self.run_text_encoder(self.input_info)
         # vision_states is all zero, because we don't have any image input
-        siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(self.config.get("run_device", "cuda"))
-        siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(self.config.get("run_device", "cuda")))
+        siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(self.run_device)
+        siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(self.run_device))
-        torch_ext_module = getattr(torch, self.config.get("run_device", "cuda"))
+        torch_ext_module = getattr(torch, self.run_device)
         torch_ext_module.empty_cache()
         gc.collect()
         return {

@@ -398,7 +400,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         siglip_output, siglip_mask = self.run_image_encoder(img_ori) if self.config.get("use_image_encoder", True) else None
         cond_latents = self.run_vae_encoder(img_ori)
         text_encoder_output = self.run_text_encoder(self.input_info)
-        torch_ext_module = getattr(torch, self.config.get("run_device", "cuda"))
+        torch_ext_module = getattr(torch, self.run_device)
         torch_ext_module.empty_cache()
         gc.collect()
         return {

@@ -425,9 +427,9 @@ class HunyuanVideo15Runner(DefaultRunner):
         target_height = self.target_height
         input_image_np = self.resize_and_center_crop(first_frame, target_width=target_width, target_height=target_height)
-        vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(self.config.get("run_device", "cuda")), dtype=torch.bfloat16)
+        vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(self.run_device), dtype=torch.bfloat16)
         image_encoder_output = self.image_encoder.infer(vision_states)
-        image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(self.config.get("run_device", "cuda")))
+        image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(self.run_device))
         return image_encoder_output, image_encoder_mask

     def resize_and_center_crop(self, image, target_width, target_height):

@@ -478,7 +480,6 @@ class HunyuanVideo15Runner(DefaultRunner):
             ]
         )
-        ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(self.config.get("run_device", "cuda"))
+        ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(self.run_device)
         cond_latents = self.vae_encoder.encode(ref_images_pixel_values.to(GET_DTYPE()))
         return cond_latents
View file @
4c0a9a0d
...
...
@@ -85,8 +85,8 @@ class QwenImageRunner(DefaultRunner):
def
_run_input_encoder_local_t2i
(
self
):
prompt
=
self
.
input_info
.
prompt
text_encoder_output
=
self
.
run_text_encoder
(
prompt
,
neg_prompt
=
self
.
input_info
.
negative_prompt
)
if
hasattr
(
torch
,
self
.
config
.
get
(
"run_device"
,
"cuda"
)
):
torch_module
=
getattr
(
torch
,
self
.
config
.
get
(
"run_device"
,
"cuda"
)
)
if
hasattr
(
torch
,
self
.
run_device
):
torch_module
=
getattr
(
torch
,
self
.
run_device
)
torch_module
.
empty_cache
()
gc
.
collect
()
return
{
...
...
@@ -102,7 +102,7 @@ class QwenImageRunner(DefaultRunner):
if
GET_RECORDER_MODE
():
width
,
height
=
img_ori
.
size
monitor_cli
.
lightx2v_input_image_len
.
observe
(
width
*
height
)
img
=
TF
.
to_tensor
(
img_ori
).
sub_
(
0.5
).
div_
(
0.5
).
unsqueeze
(
0
).
to
(
self
.
init
_device
)
img
=
TF
.
to_tensor
(
img_ori
).
sub_
(
0.5
).
div_
(
0.5
).
unsqueeze
(
0
).
to
(
self
.
run
_device
)
self
.
input_info
.
original_size
.
append
(
img_ori
.
size
)
return
img
,
img_ori
...
...
@@ -121,8 +121,8 @@ class QwenImageRunner(DefaultRunner):
for
vae_image
in
text_encoder_output
[
"image_info"
][
"vae_image_list"
]:
image_encoder_output
=
self
.
run_vae_encoder
(
image
=
vae_image
)
image_encoder_output_list
.
append
(
image_encoder_output
)
if
hasattr
(
torch
,
self
.
config
.
get
(
"run_device"
,
"cuda"
)
):
torch_module
=
getattr
(
torch
,
self
.
config
.
get
(
"run_device"
,
"cuda"
)
)
if
hasattr
(
torch
,
self
.
run_device
):
torch_module
=
getattr
(
torch
,
self
.
run_device
)
torch_module
.
empty_cache
()
gc
.
collect
()
return
{
...
...
@@ -238,8 +238,8 @@ class QwenImageRunner(DefaultRunner):
images
=
self
.
vae
.
decode
(
latents
,
self
.
input_info
)
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
del
self
.
vae_decoder
if
hasattr
(
torch
,
self
.
config
.
get
(
"run_device"
,
"cuda"
)
):
torch_module
=
getattr
(
torch
,
self
.
config
.
get
(
"run_device"
,
"cuda"
)
)
if
hasattr
(
torch
,
self
.
run_device
):
torch_module
=
getattr
(
torch
,
self
.
run_device
)
torch_module
.
empty_cache
()
gc
.
collect
()
return
images
...
...
@@ -259,8 +259,8 @@ class QwenImageRunner(DefaultRunner):
image
.
save
(
f
"
{
input_info
.
save_result_path
}
"
)
del
latents
,
generator
if
hasattr
(
torch
,
self
.
config
.
get
(
"run_device"
,
"cuda"
)
):
torch_module
=
getattr
(
torch
,
self
.
config
.
get
(
"run_device"
,
"cuda"
)
)
if
hasattr
(
torch
,
self
.
run_device
):
torch_module
=
getattr
(
torch
,
self
.
run_device
)
torch_module
.
empty_cache
()
gc
.
collect
()
...
...
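The runners flush accelerator memory through getattr(torch, run_device).empty_cache(), which works for any backend whose device module exposes empty_cache (torch.cuda, torch.mps, torch.mlu with the vendor extension, and so on). A minimal standalone sketch of the same idea; the helper is illustrative, not the repo's:

import torch

def empty_device_cache(run_device: str) -> None:
    # Backend-agnostic cache flush: only touch the device module if it exists
    # and actually provides empty_cache().
    if hasattr(torch, run_device):
        module = getattr(torch, run_device)
        if hasattr(module, "empty_cache"):
            module.empty_cache()

empty_device_cache("cuda")  # no-op when CUDA is not initialized, flushes the caching allocator otherwise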
lightx2v/models/runners/wan/wan_audio_runner.py  (+5 −5)

@@ -835,7 +835,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_audio_encoder(self):
         audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"))
         audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
-        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, device=self.config.get("run_device", "cuda"))
+        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, run_device=self.config.get("run_device", "cuda"))
         return model

     def load_audio_adapter(self):

@@ -843,7 +843,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if audio_adapter_offload:
             device = torch.device("cpu")
         else:
-            device = torch.device(self.config.get("run_device", "cuda"))
+            device = torch.device(self.run_device)
         audio_adapter = AudioAdapter(
             attention_head_dim=self.config["dim"] // self.config["num_heads"],
             num_attention_heads=self.config["num_heads"],

@@ -856,7 +856,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             quantized=self.config.get("adapter_quantized", False),
             quant_scheme=self.config.get("adapter_quant_scheme", None),
             cpu_offload=audio_adapter_offload,
-            device=self.config.get("run_device", "cuda"),
+            run_device=self.run_device,
         )
         audio_adapter.to(device)

@@ -892,7 +892,7 @@ class Wan22AudioRunner(WanAudioRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.config.get("run_device", "cuda"))
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,

@@ -909,7 +909,7 @@ class Wan22AudioRunner(WanAudioRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.config.get("run_device", "cuda"))
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
lightx2v/models/runners/wan/wan_runner.py  (+7 −6)

@@ -65,7 +65,7 @@ class WanRunner(DefaultRunner):
         if clip_offload:
             clip_device = torch.device("cpu")
         else:
-            clip_device = torch.device(self.init_device)
+            clip_device = torch.device(self.run_device)
         # quant_config
         clip_quantized = self.config.get("clip_quantized", False)
         if clip_quantized:

@@ -123,6 +123,7 @@ class WanRunner(DefaultRunner):
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
             device=t5_device,
+            run_device=self.run_device,
             checkpoint_path=t5_original_ckpt,
             tokenizer_path=tokenizer_path,
             shard_fn=None,

@@ -141,7 +142,7 @@ class WanRunner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.init_device)
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),

@@ -320,7 +321,7 @@ class WanRunner(DefaultRunner):
             self.config["target_video_length"],
             lat_h,
             lat_w,
-            device=torch.device(self.config.get("run_device", "cuda")),
+            device=torch.device(self.run_device),
         )
         if last_frame is not None:
             msk[:, 1:-1] = 0

@@ -342,7 +343,7 @@ class WanRunner(DefaultRunner):
                     torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                 ],
                 dim=1,
-            ).to(self.init_device)
+            ).to(self.run_device)
         else:
             vae_input = torch.concat(
                 [

@@ -350,7 +351,7 @@ class WanRunner(DefaultRunner):
                     torch.zeros(3, self.config["target_video_length"] - 1, h, w),
                 ],
                 dim=1,
-            ).to(self.init_device)
+            ).to(self.run_device)

         vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()))

@@ -533,7 +534,7 @@ class Wan22DenseRunner(WanRunner):
         assert img.width == ow and img.height == oh

         # to tensor
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.init_device).unsqueeze(1)
+        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.run_device).unsqueeze(1)
         vae_encoder_out = self.get_vae_encoder_output(img)
         latent_w, latent_h = ow // self.config["vae_stride"][2], oh // self.config["vae_stride"][1]
         latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
lightx2v/models/schedulers/hunyuan_video/posemb_layers.py  (+2 −2)

@@ -271,8 +271,8 @@ def get_1d_rotary_pos_embed(
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
     # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
-    device = kwds.get("device", "cuda")
-    freqs = torch.outer(pos * interpolation_factor, freqs).to(device)  # [S, D/2]
+    run_device = kwds.get("run_device", "cuda")
+    freqs = torch.outer(pos * interpolation_factor, freqs).to(run_device)  # [S, D/2]
     if use_real:
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
lightx2v/models/schedulers/hunyuan_video/scheduler.py  (+6 −6)

@@ -11,7 +11,7 @@ from .posemb_layers import get_nd_rotary_pos_embed
 class HunyuanVideo15Scheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.device = torch.device(self.config.get("run_device", "cuda"))
+        self.run_device = torch.device(self.config.get("run_device", "cuda"))
         self.reverse = True
         self.num_train_timesteps = 1000
         self.sample_shift = self.config["sample_shift"]

@@ -25,13 +25,13 @@ class HunyuanVideo15Scheduler(BaseScheduler):
     def prepare(self, seed, latent_shape, image_encoder_output=None):
         self.prepare_latents(seed, latent_shape, dtype=torch.bfloat16)
-        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
         self.multitask_mask = self.get_task_mask(self.config["task"], latent_shape[-3])
         self.cond_latents_concat, self.mask_concat = self._prepare_cond_latents_and_mask(self.config["task"], image_encoder_output["cond_latents"], self.latents, self.multitask_mask, self.reorg_token)
         self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))

     def prepare_latents(self, seed, latent_shape, dtype=torch.bfloat16):
-        self.generator = torch.Generator(device=self.device).manual_seed(seed)
+        self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
         self.latents = torch.randn(
             1,
             latent_shape[0],

@@ -39,7 +39,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
             latent_shape[2],
             latent_shape[3],
             dtype=dtype,
-            device=self.device,
+            device=self.run_device,
             generator=self.generator,
         )

@@ -127,7 +127,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
         if rope_dim_list is None:
             rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
         assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
-        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=self.device)
+        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=self.run_device)
         cos_half = freqs_cos[:, ::2].contiguous()
         sin_half = freqs_sin[:, ::2].contiguous()
         cos_sin = torch.cat([cos_half, sin_half], dim=-1)

@@ -151,7 +151,7 @@ class HunyuanVideo15SRScheduler(HunyuanVideo15Scheduler):
         dtype = lq_latents.dtype
         device = lq_latents.device
         self.prepare_latents(seed, latent_shape, lq_latents, dtype=dtype)
-        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
         self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))
         tgt_shape = latent_shape[-2:]
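The scheduler now seeds its random generator on the configured run_device instead of a fixed one; the generator's device has to match the device of the tensors it fills. A minimal sketch on CPU (the device string would be "cuda", "mlu", etc. in practice):

import torch

run_device = torch.device("cpu")  # stand-in for the configured run_device
gen = torch.Generator(device=run_device).manual_seed(42)
latents = torch.randn(1, 16, 8, 8, generator=gen, device=run_device)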