xuwx1 / LightX2V / Commits / 4c0a9a0d
"LinuxGUI/vscode:/vscode.git/clone" did not exist on "1840cb8f38f3ee14a50a5779ecbf7897f8f02bd9"
Fix device bugs (#527)

Unverified commit 4c0a9a0d, authored Nov 27, 2025 by Gu Shiqiao; committed by GitHub on Nov 27, 2025.
Parent commit: fbb19ffc
Changes: 29. Showing 20 changed files with 142 additions and 77 deletions (+142, -77).
configs/seko_talk/mlu/seko_talk_bf16.json                           +19   -0
lightx2v/common/ops/mm/mm_weight.py                                 +19   -0
lightx2v/infer.py                                                    +5   -1
lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py            +6   -5
lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py          +6   -5
lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py          +7   -2
lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py        +4   -4
lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py        +6   -5
lightx2v/models/input_encoders/hf/wan/t5/model.py                    +5   -3
lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py           +0   -1
lightx2v/models/networks/hunyuan_video/infer/pre_infer.py            +2   -2
lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py    +2   -2
lightx2v/models/networks/wan/infer/transformer_infer.py             +10   -2
lightx2v/models/networks/wan/infer/utils.py                          +5   -1
lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py    +17  -16
lightx2v/models/runners/qwen_image/qwen_image_runner.py              +9   -9
lightx2v/models/runners/wan/wan_audio_runner.py                      +5   -5
lightx2v/models/runners/wan/wan_runner.py                            +7   -6
lightx2v/models/schedulers/hunyuan_video/posemb_layers.py            +2   -2
lightx2v/models/schedulers/hunyuan_video/scheduler.py                +6   -6
configs/seko_talk/mlu/seko_talk_bf16.json (new file, mode 0 → 100644)

{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 360,
    "audio_sr": 16000,
    "target_video_length": 81,
    "resize_mode": "adaptive",
    "self_attn_1_type": "flash_attn2",
    "cross_attn_1_type": "flash_attn2",
    "cross_attn_2_type": "flash_attn2",
    "sample_guide_scale": 1.0,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "use_31_block": false,
    "run_device": "mlu",
    "rope_type": "torch",
    "modulate_type": "torch"
}
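The new MLU config exercises the "run_device" key that this commit threads through the runners, encoders, and schedulers below; everywhere else in the diff the key is read with a "cuda" fallback. A minimal sketch of that lookup (reading the file directly like this is illustrative, not how LightX2V loads its configs):

import json

with open("configs/seko_talk/mlu/seko_talk_bf16.json") as f:
    config = json.load(f)

# Same pattern as the code below: fall back to "cuda" when the key is absent.
run_device = config.get("run_device", "cuda")   # -> "mlu" for this config
# torch.device(run_device) additionally requires the vendor extension
# (e.g. torch_mlu) to be importable for non-CUDA device strings.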
lightx2v/common/ops/mm/mm_weight.py

@@ -442,6 +442,25 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             else:
                 raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+        if self.bias_name is not None:
+            if self.create_cuda_buffer:
+                # move to cuda buffer
+                self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
+            else:
+                device = weight_dict[self.bias_name].device
+                if device.type == "cuda":
+                    self.bias = weight_dict[self.bias_name]
+                elif device.type == "cpu":
+                    bias_shape = weight_dict[self.bias_name].shape
+                    bias_dtype = weight_dict[self.bias_name].dtype
+                    self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                    self.pin_bias.copy_(weight_dict[self.bias_name])
+                else:
+                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+        else:
+            self.bias = None
+            self.pin_bias = None

     def load_fp8_perblock128_sym(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
             self.weight = weight_dict[self.weight_name]
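The added bias branch stages CPU-resident bias tensors in page-locked (pinned) host memory so later host-to-device copies can run asynchronously, while CUDA-resident tensors are used as-is. A small self-contained sketch of that staging idea (the helper name stage_bias is illustrative, not from the repo):

import torch

def stage_bias(bias: torch.Tensor, device: str = "cuda") -> torch.Tensor:
    # Mirror of the pin_memory pattern above: only CPU tensors need staging,
    # and pinning is only meaningful when a CUDA device is present.
    if bias.device.type != "cpu" or not torch.cuda.is_available():
        return bias
    pinned = torch.empty(bias.shape, dtype=bias.dtype, pin_memory=True)
    pinned.copy_(bias)
    # Pinned memory lets this transfer overlap with compute on the GPU.
    return pinned.to(device, non_blocking=True)

staged = stage_bias(torch.randn(4096))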
lightx2v/infer.py

@@ -3,7 +3,11 @@ import argparse
 import torch
 import torch.distributed as dist
 from loguru import logger
-from torch.distributed import ProcessGroupNCCL
+try:
+    from torch.distributed import ProcessGroupNCCL
+except ImportError:
+    ProcessGroupNCCL = None
 from lightx2v.common.ops import *
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner  # noqa: F401
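Wrapping the ProcessGroupNCCL import in try/except lets lightx2v/infer.py be imported on PyTorch builds without NCCL (CPU-only wheels, or non-CUDA backends such as the MLU config above). A hedged sketch of how such a flag could drive backend selection; choosing "gloo" as the fallback is an assumption for illustration, not something this commit does:

import torch.distributed as dist

try:
    from torch.distributed import ProcessGroupNCCL
except ImportError:  # PyTorch built without NCCL support
    ProcessGroupNCCL = None

def pick_backend() -> str:
    # Assumption: prefer NCCL when the build provides it, otherwise gloo.
    return "nccl" if ProcessGroupNCCL is not None else "gloo"

# dist.init_process_group(backend=pick_backend()) then works on both kinds of builds.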
lightx2v/models/input_encoders/hf/hunyuan15/byt5/model.py

@@ -159,13 +159,14 @@ class ByT5TextEncoder:
         self,
         config,
         device=torch.device("cpu"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         byt5_max_length=256,
         cpu_offload=False,
     ):
         self.cpu_offload = cpu_offload
         self.config = config
-        self.device = device
+        self.run_device = run_device
         self.byt5_max_length = byt5_max_length
         self.enable_cfg = config.get("enable_cfg", False)
         byT5_google_path = os.path.join(checkpoint_path, "text_encoder", "byt5-small")

@@ -300,12 +301,12 @@ class ByT5TextEncoder:
         negative_masks = []
         for prompt in prompt_list:
-            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, self.device)
+            pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, self.run_device)
             positive_embeddings.append(pos_emb)
             positive_masks.append(pos_mask)
             if self.enable_cfg:
                 # TODO: split CFG handling out; it is better suited to parallel execution
-                neg_emb, neg_mask = self._process_single_byt5_prompt("", self.device)
+                neg_emb, neg_mask = self._process_single_byt5_prompt("", self.run_device)
                 negative_embeddings.append(neg_emb)
                 negative_masks.append(neg_mask)

@@ -327,8 +328,8 @@ class ByT5TextEncoder:
     @torch.no_grad()
     def infer(self, prompts):
         if self.cpu_offload:
-            self.byt5_model = self.byt5_model.to(self.device)
-            self.byt5_mapper = self.byt5_mapper.to(self.device)
+            self.byt5_model = self.byt5_model.to(self.run_device)
+            self.byt5_mapper = self.byt5_mapper.to(self.run_device)
         byt5_embeddings, byt5_masks = self._prepare_byt5_embeddings(prompts)
         byt5_features = self.byt5_mapper(byt5_embeddings.to(torch.bfloat16))
         if self.cpu_offload:
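The infer() change shows the usual CPU-offload round trip: with cpu_offload set, weights rest on the CPU between calls, are moved to run_device right before the forward pass, and (in the truncated tail of this hunk) are moved back afterwards. A generic sketch of that round trip, runnable on CPU (module and device names are illustrative):

import torch
import torch.nn as nn

def offloaded_forward(module: nn.Module, x: torch.Tensor, run_device, cpu_offload: bool):
    if cpu_offload:
        module.to(run_device)        # weights on the compute device only for this call
    out = module(x.to(run_device))
    if cpu_offload:
        module.to("cpu")             # free accelerator memory between calls
    return out

layer = nn.Linear(8, 8)
y = offloaded_forward(layer, torch.randn(2, 8), run_device="cpu", cpu_offload=True)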
lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py

@@ -552,6 +552,7 @@ class Qwen25VL_TextEncoder:
         text_len=1000,
         dtype=torch.float16,
         device=torch.device("cpu"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
         qwen25vl_quantized=False,

@@ -560,7 +561,7 @@ class Qwen25VL_TextEncoder:
     ):
         self.text_len = text_len
         self.dtype = dtype
-        self.device = device
+        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.qwen25vl_quantized = qwen25vl_quantized
         self.qwen25vl_quant_scheme = qwen25vl_quant_scheme

@@ -589,20 +590,20 @@ class Qwen25VL_TextEncoder:
     def infer(self, texts):
         if self.cpu_offload:
-            self.text_encoder = self.text_encoder.to(self.device)
+            self.text_encoder = self.text_encoder.to(self.run_device)
         text_inputs = self.text_encoder.text2tokens(texts, data_type="video", max_length=self.text_len)
-        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=self.device)
+        prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device=self.run_device)
         if self.cpu_offload:
             self.text_encoder = self.text_encoder.to("cpu")
         prompt_embeds = prompt_outputs.hidden_state
         attention_mask = prompt_outputs.attention_mask
         if attention_mask is not None:
-            attention_mask = attention_mask.to(self.device)
+            attention_mask = attention_mask.to(self.run_device)
             _, seq_len = attention_mask.shape
             attention_mask = attention_mask.repeat(1, self.num_videos_per_prompt)
             attention_mask = attention_mask.view(self.num_videos_per_prompt, seq_len)
-        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
+        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.run_device)
         seq_len = prompt_embeds.shape[1]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
lightx2v/models/input_encoders/hf/hunyuan15/siglip/model.py

@@ -95,6 +95,7 @@ class VisionEncoder(nn.Module):
         output_key: Optional[str] = None,
         logger=None,
         device=None,
+        run_device=None,
         cpu_offload=False,
     ):
         super().__init__()

@@ -120,6 +121,7 @@ class VisionEncoder(nn.Module):
         )
         self.dtype = self.model.dtype
         self.device = self.model.device
+        self.run_device = run_device
         self.processor, self.processor_path = load_image_processor(
             processor_type=self.processor_type,

@@ -175,7 +177,7 @@ class VisionEncoder(nn.Module):
         if isinstance(images, np.ndarray):
             # Preprocess images if they're numpy arrays
-            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=self.device, dtype=self.model.dtype)
+            preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device=self.run_device, dtype=self.model.dtype)
         else:
             # Assume already preprocessed
             preprocessed = images

@@ -230,11 +232,13 @@ class SiglipVisionEncoder:
         self,
         config,
         device=torch.device("cpu"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         cpu_offload=False,
     ):
         self.config = config
         self.device = device
+        self.run_device = run_device
         self.cpu_offload = cpu_offload
         self.vision_states_dim = 1152
         vision_encoder_path = os.path.join(checkpoint_path, "vision_encoder", "siglip")

@@ -248,6 +252,7 @@ class SiglipVisionEncoder:
             output_key=None,
             logger=None,
             device=self.device,
+            run_device=self.run_device,
             cpu_offload=self.cpu_offload,
         )

@@ -265,7 +270,7 @@ class SiglipVisionEncoder:
     @torch.no_grad()
     def infer(self, vision_states):
         if self.cpu_offload:
-            self.vision_in = self.vision_in.to("cuda")
+            self.vision_in = self.vision_in.to(self.run_device)
         vision_states = self.vision_in(vision_states)
         if self.cpu_offload:
             self.vision_in = self.vision_in.to("cpu")
lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py

@@ -252,7 +252,7 @@ class AudioAdapter(nn.Module):
         quantized: bool = False,
         quant_scheme: str = None,
         cpu_offload: bool = False,
         device=torch.device("cpu"),
+        run_device=torch.device("cuda"),
     ):
         super().__init__()
         self.cpu_offload = cpu_offload

@@ -263,7 +263,7 @@ class AudioAdapter(nn.Module):
             mlp_dims=mlp_dims,
             transformer_layers=projection_transformer_layers,
         )
         self.device = torch.device(device)
+        self.run_device = run_device
         # self.num_tokens = num_tokens * 4
         self.num_tokens_x4 = num_tokens * 4
         self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)

@@ -302,10 +302,10 @@ class AudioAdapter(nn.Module):
     @torch.no_grad()
     def forward_audio_proj(self, audio_feat, latent_frame):
         if self.cpu_offload:
-            self.audio_proj.to(self.device)
+            self.audio_proj.to(self.run_device)
         x = self.audio_proj(audio_feat, latent_frame)
         x = self.rearange_audio_features(x)
-        x = x + self.audio_pe.to(self.device)
+        x = x + self.audio_pe.to(self.run_device)
         if self.cpu_offload:
             self.audio_proj.to("cpu")
         return x
lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py

@@ -5,14 +5,15 @@ from lightx2v.utils.envs import *
 class SekoAudioEncoderModel:
-    def __init__(self, model_path, audio_sr, cpu_offload, device):
+    def __init__(self, model_path, audio_sr, cpu_offload, run_device):
         self.model_path = model_path
         self.audio_sr = audio_sr
         self.cpu_offload = cpu_offload
         if self.cpu_offload:
             self.device = torch.device("cpu")
         else:
-            self.device = torch.device(device)
+            self.device = torch.device(run_device)
+        self.run_device = run_device
         self.load()

     def load(self):

@@ -26,13 +27,13 @@ class SekoAudioEncoderModel:
         self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")

     def to_cuda(self):
-        self.audio_feature_encoder = self.audio_feature_encoder.to(self.device)
+        self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)

     @torch.no_grad()
     def infer(self, audio_segment):
-        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.device).to(dtype=GET_DTYPE())
+        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.run_device).to(dtype=GET_DTYPE())
         if self.cpu_offload:
-            self.audio_feature_encoder = self.audio_feature_encoder.to(self.device)
+            self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
         audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
         if self.cpu_offload:
             self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
lightx2v/models/input_encoders/hf/wan/t5/model.py

@@ -744,7 +744,8 @@ class T5EncoderModel:
         self,
         text_len,
         dtype=torch.bfloat16,
-        device=torch.device("cpu"),
+        device=torch.device("cuda"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         tokenizer_path=None,
         shard_fn=None,

@@ -757,6 +758,7 @@ class T5EncoderModel:
         self.text_len = text_len
         self.dtype = dtype
         self.device = device
+        self.run_device = run_device
         if t5_quantized_ckpt is not None and t5_quantized:
             self.checkpoint_path = t5_quantized_ckpt
         else:

@@ -805,8 +807,8 @@ class T5EncoderModel:
     def infer(self, texts):
         ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
-        ids = ids.to(self.device)
-        mask = mask.to(self.device)
+        ids = ids.to(self.run_device)
+        mask = mask.to(self.run_device)
         seq_lens = mask.gt(0).sum(dim=1).long()
         with torch.no_grad():
lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py

@@ -428,7 +428,6 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
     def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False, run_device=torch.device("cuda")):
         self.dtype = dtype
         self.device = device
         self.run_device = run_device
         self.quantized = clip_quantized
         self.cpu_offload = cpu_offload
lightx2v/models/networks/hunyuan_video/infer/pre_infer.py

@@ -68,7 +68,7 @@ class HunyuanVideo15PreInfer:
         self.heads_num = config["heads_num"]
         self.frequency_embedding_size = 256
         self.max_period = 10000
-        self.device = torch.device(self.config.get("run_device", "cuda"))
+        self.run_device = torch.device(self.config.get("run_device", "cuda"))

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

@@ -155,7 +155,7 @@ class HunyuanVideo15PreInfer:
         byt5_txt = byt5_txt + weights.cond_type_embedding.apply(torch.ones_like(byt5_txt[:, :, 0], device=byt5_txt.device, dtype=torch.long))
         txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=True)
-        siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=self.device))
+        siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=self.run_device))
         txt, text_mask = self.reorder_txt_token(siglip_output, txt, siglip_mask, text_mask)
         txt = txt[:, : text_mask.sum(), :]
lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py

@@ -100,7 +100,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
         self.config = config
         self.double_blocks_num = config["mm_double_blocks_depth"]
         self.heads_num = config["heads_num"]
-        self.device = torch.device(self.config.get("run_device", "cuda"))
+        self.run_device = torch.device(self.config.get("run_device", "cuda"))
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
             self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)

@@ -222,7 +222,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
         key = torch.cat([img_k, txt_k], dim=1)
         value = torch.cat([img_v, txt_v], dim=1)
         seqlen = query.shape[1]
-        cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(self.device, non_blocking=True)
+        cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(self.run_device, non_blocking=True)
         if self.config["seq_parallel"]:
             attn_out = weights.self_attention_parallel.apply(
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -9,6 +9,10 @@ from .triton_ops import fuse_scale_shift_kernel
 from .utils import apply_wan_rope_with_chunk, apply_wan_rope_with_flashinfer, apply_wan_rope_with_torch


+def modulate(x, scale, shift):
+    return x * (1 + scale.squeeze()) + shift.squeeze()


 class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config

@@ -21,6 +25,10 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
+        if self.config.get("modulate_type", "triton") == "triton":
+            self.modulate_func = fuse_scale_shift_kernel
+        else:
+            self.modulate_func = modulate
         if self.config.get("rope_type", "flashinfer") == "flashinfer":
             if self.config.get("rope_chunk", False):
                 self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_flashinfer)

@@ -146,7 +154,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         norm1_out = phase.norm1.apply(x)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm1_out = norm1_out.to(self.sensitive_layer_dtype)
-        norm1_out = fuse_scale_shift_kernel(norm1_out, scale=scale_msa, shift=shift_msa).squeeze(0)
+        norm1_out = self.modulate_func(norm1_out, scale=scale_msa, shift=shift_msa).squeeze()
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm1_out = norm1_out.to(self.infer_dtype)

@@ -285,7 +293,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         norm2_out = phase.norm2.apply(x)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm2_out = norm2_out.to(self.sensitive_layer_dtype)
-        norm2_out = fuse_scale_shift_kernel(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze(0)
+        norm2_out = self.modulate_func(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze()
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm2_out = norm2_out.to(self.infer_dtype)
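The new modulate helper is a plain-PyTorch fallback for the Triton fuse_scale_shift_kernel; it is selected when the config sets "modulate_type" to "torch", as the MLU config above does. A small self-contained check of what the fallback computes (tensor shapes are illustrative only):

import torch

def modulate(x, scale, shift):
    # Same expression as the fallback added above: x * (1 + scale) + shift,
    # with singleton dims squeezed off scale and shift.
    return x * (1 + scale.squeeze()) + shift.squeeze()

x = torch.randn(16, 8)        # token features
scale = torch.randn(1, 8)     # per-channel modulation scale
shift = torch.randn(1, 8)     # per-channel modulation shift
out = modulate(x, scale, shift)
assert torch.allclose(out, x * (1 + scale) + shift)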
lightx2v/models/networks/wan/infer/utils.py

 import torch
 import torch.distributed as dist
-from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+try:
+    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
+except ImportError:
+    apply_rope_with_cos_sin_cache_inplace = None
 from lightx2v.utils.envs import *
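Guarding the flashinfer import makes that dependency optional; the MLU config pairs this with "rope_type": "torch", so the pure-torch RoPE path in transformer_infer.py is used instead. A hedged sketch of that kind of dispatch; the select function and the stub rope implementations are illustrative, only the config key and its default mirror the repo:

try:
    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
except ImportError:
    apply_rope_with_cos_sin_cache_inplace = None

def rope_torch(q, k, cos_sin):
    ...  # stand-in for apply_wan_rope_with_torch

def rope_flashinfer(q, k, cos_sin):
    ...  # stand-in for apply_wan_rope_with_flashinfer

def select_rope_func(config: dict):
    use_flashinfer = (
        config.get("rope_type", "flashinfer") == "flashinfer"
        and apply_rope_with_cos_sin_cache_inplace is not None
    )
    return rope_flashinfer if use_flashinfer else rope_torch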
lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py

@@ -71,7 +71,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if qwen25vl_offload:
             qwen25vl_device = torch.device("cpu")
         else:
-            qwen25vl_device = torch.device(self.config.get("run_device", "cuda"))
+            qwen25vl_device = torch.device(self.run_device)
         qwen25vl_quantized = self.config.get("qwen25vl_quantized", False)
         qwen25vl_quant_scheme = self.config.get("qwen25vl_quant_scheme", None)

@@ -82,6 +82,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         text_encoder = Qwen25VL_TextEncoder(
             dtype=torch.float16,
             device=qwen25vl_device,
+            run_device=self.run_device,
             checkpoint_path=text_encoder_path,
             cpu_offload=qwen25vl_offload,
             qwen25vl_quantized=qwen25vl_quantized,

@@ -93,9 +94,9 @@ class HunyuanVideo15Runner(DefaultRunner):
         if byt5_offload:
             byt5_device = torch.device("cpu")
         else:
-            byt5_device = torch.device(self.config.get("run_device", "cuda"))
+            byt5_device = torch.device(self.run_device)
-        byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
+        byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, run_device=self.run_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
         text_encoders = [text_encoder, byt5]
         return text_encoders

@@ -229,10 +230,11 @@ class HunyuanVideo15Runner(DefaultRunner):
         if siglip_offload:
             siglip_device = torch.device("cpu")
         else:
-            siglip_device = torch.device(self.config.get("run_device", "cuda"))
+            siglip_device = torch.device(self.run_device)
         image_encoder = SiglipVisionEncoder(
             config=self.config,
             device=siglip_device,
+            run_device=self.run_device,
             checkpoint_path=self.config["model_path"],
             cpu_offload=siglip_offload,
         )

@@ -244,7 +246,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.config.get("run_device", "cuda"))
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "checkpoint_path": self.config["model_path"],

@@ -263,7 +265,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.config.get("run_device", "cuda"))
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "checkpoint_path": self.config["model_path"],

@@ -273,7 +275,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         }
         if self.config.get("use_tae", False):
             tae_path = self.config["tae_path"]
-            vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(self.config.get("run_device", "cuda"))
+            vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(self.run_device)
         else:
             vae_decoder = self.vae_cls(**vae_config)
         return vae_decoder

@@ -348,7 +350,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         self.model_sr.scheduler.step_post()
         del self.inputs_sr
-        torch_ext_module = getattr(torch, self.config.get("run_device", "cuda"))
+        torch_ext_module = getattr(torch, self.run_device)
         torch_ext_module.empty_cache()
         self.config_sr["is_sr_running"] = False

@@ -367,10 +369,10 @@ class HunyuanVideo15Runner(DefaultRunner):
         text_encoder_output = self.run_text_encoder(self.input_info)
         # vision_states is all zero, because we don't have any image input
-        siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(self.config.get("run_device", "cuda"))
-        siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(self.config.get("run_device", "cuda")))
+        siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(self.run_device)
+        siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(self.run_device))
-        torch_ext_module = getattr(torch, self.config.get("run_device", "cuda"))
+        torch_ext_module = getattr(torch, self.run_device)
         torch_ext_module.empty_cache()
         gc.collect()
         return {

@@ -398,7 +400,7 @@ class HunyuanVideo15Runner(DefaultRunner):
         siglip_output, siglip_mask = self.run_image_encoder(img_ori) if self.config.get("use_image_encoder", True) else None
         cond_latents = self.run_vae_encoder(img_ori)
         text_encoder_output = self.run_text_encoder(self.input_info)
-        torch_ext_module = getattr(torch, self.config.get("run_device", "cuda"))
+        torch_ext_module = getattr(torch, self.run_device)
         torch_ext_module.empty_cache()
         gc.collect()
         return {

@@ -425,9 +427,9 @@ class HunyuanVideo15Runner(DefaultRunner):
         target_height = self.target_height
         input_image_np = self.resize_and_center_crop(first_frame, target_width=target_width, target_height=target_height)
-        vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(self.config.get("run_device", "cuda")), dtype=torch.bfloat16)
+        vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(self.run_device), dtype=torch.bfloat16)
         image_encoder_output = self.image_encoder.infer(vision_states)
-        image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(self.config.get("run_device", "cuda")))
+        image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(self.run_device))
         return image_encoder_output, image_encoder_mask

     def resize_and_center_crop(self, image, target_width, target_height):

@@ -478,7 +480,6 @@ class HunyuanVideo15Runner(DefaultRunner):
             ]
         )
-        ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(self.config.get("run_device", "cuda"))
+        ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(self.run_device)
         cond_latents = self.vae_encoder.encode(ref_images_pixel_values.to(GET_DTYPE()))
         return cond_latents
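Every loader in this runner now follows the same rule: a module that will be CPU-offloaded is constructed on "cpu", everything else goes straight to self.run_device. The same rule as a standalone sketch (the helper name resolve_load_device is hypothetical):

import torch

def resolve_load_device(offload: bool, run_device: str = "cuda") -> torch.device:
    # Offloaded modules start on the CPU; others load directly on the run device.
    return torch.device("cpu") if offload else torch.device(run_device)

# e.g. siglip_device = resolve_load_device(siglip_offload, run_device="cuda")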
lightx2v/models/runners/qwen_image/qwen_image_runner.py

@@ -85,8 +85,8 @@ class QwenImageRunner(DefaultRunner):
     def _run_input_encoder_local_t2i(self):
         prompt = self.input_info.prompt
         text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
-        if hasattr(torch, self.config.get("run_device", "cuda")):
-            torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+        if hasattr(torch, self.run_device):
+            torch_module = getattr(torch, self.run_device)
             torch_module.empty_cache()
         gc.collect()
         return {

@@ -102,7 +102,7 @@ class QwenImageRunner(DefaultRunner):
         if GET_RECORDER_MODE():
             width, height = img_ori.size
             monitor_cli.lightx2v_input_image_len.observe(width * height)
-        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device)
+        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.run_device)
         self.input_info.original_size.append(img_ori.size)
         return img, img_ori

@@ -121,8 +121,8 @@ class QwenImageRunner(DefaultRunner):
         for vae_image in text_encoder_output["image_info"]["vae_image_list"]:
             image_encoder_output = self.run_vae_encoder(image=vae_image)
             image_encoder_output_list.append(image_encoder_output)
-        if hasattr(torch, self.config.get("run_device", "cuda")):
-            torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+        if hasattr(torch, self.run_device):
+            torch_module = getattr(torch, self.run_device)
             torch_module.empty_cache()
         gc.collect()
         return {

@@ -238,8 +238,8 @@ class QwenImageRunner(DefaultRunner):
         images = self.vae.decode(latents, self.input_info)
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_decoder
-        if hasattr(torch, self.config.get("run_device", "cuda")):
-            torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+        if hasattr(torch, self.run_device):
+            torch_module = getattr(torch, self.run_device)
             torch_module.empty_cache()
         gc.collect()
         return images

@@ -259,8 +259,8 @@ class QwenImageRunner(DefaultRunner):
             image.save(f"{input_info.save_result_path}")
         del latents, generator
-        if hasattr(torch, self.config.get("run_device", "cuda")):
-            torch_module = getattr(torch, self.config.get("run_device", "cuda"))
+        if hasattr(torch, self.run_device):
+            torch_module = getattr(torch, self.run_device)
             torch_module.empty_cache()
         gc.collect()
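The cache-clearing blocks lean on the fact that PyTorch exposes per-backend modules under the device-type name (torch.cuda, torch.mps, and a registered torch.mlu when the vendor extension is installed), so getattr(torch, run_device) resolves the right one when run_device is a plain string. A small hedged helper capturing that idiom (empty_device_cache is not a repo function):

import gc
import torch

def empty_device_cache(run_device: str = "cuda") -> None:
    backend = getattr(torch, run_device, None)   # torch.cuda, torch.mlu, ...
    if backend is not None and hasattr(backend, "empty_cache"):
        backend.empty_cache()
    gc.collect()

# empty_device_cache("cuda")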
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -835,7 +835,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_audio_encoder(self):
         audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"))
         audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
-        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, device=self.config.get("run_device", "cuda"))
+        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, run_device=self.config.get("run_device", "cuda"))
         return model

     def load_audio_adapter(self):

@@ -843,7 +843,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if audio_adapter_offload:
             device = torch.device("cpu")
         else:
-            device = torch.device(self.config.get("run_device", "cuda"))
+            device = torch.device(self.run_device)
         audio_adapter = AudioAdapter(
             attention_head_dim=self.config["dim"] // self.config["num_heads"],
             num_attention_heads=self.config["num_heads"],

@@ -856,7 +856,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             quantized=self.config.get("adapter_quantized", False),
             quant_scheme=self.config.get("adapter_quant_scheme", None),
             cpu_offload=audio_adapter_offload,
-            device=self.config.get("run_device", "cuda"),
+            run_device=self.run_device,
         )
         audio_adapter.to(device)

@@ -892,7 +892,7 @@ class Wan22AudioRunner(WanAudioRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.config.get("run_device", "cuda"))
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,

@@ -909,7 +909,7 @@ class Wan22AudioRunner(WanAudioRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.config.get("run_device", "cuda"))
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
lightx2v/models/runners/wan/wan_runner.py

@@ -65,7 +65,7 @@ class WanRunner(DefaultRunner):
         if clip_offload:
             clip_device = torch.device("cpu")
         else:
-            clip_device = torch.device(self.init_device)
+            clip_device = torch.device(self.run_device)
         # quant_config
         clip_quantized = self.config.get("clip_quantized", False)
         if clip_quantized:

@@ -123,6 +123,7 @@ class WanRunner(DefaultRunner):
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
             device=t5_device,
+            run_device=self.run_device,
             checkpoint_path=t5_original_ckpt,
             tokenizer_path=tokenizer_path,
             shard_fn=None,

@@ -141,7 +142,7 @@ class WanRunner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.init_device)
+            vae_device = torch.device(self.run_device)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),

@@ -320,7 +321,7 @@ class WanRunner(DefaultRunner):
             self.config["target_video_length"],
             lat_h,
             lat_w,
-            device=torch.device(self.config.get("run_device", "cuda")),
+            device=torch.device(self.run_device),
         )
         if last_frame is not None:
             msk[:, 1:-1] = 0

@@ -342,7 +343,7 @@ class WanRunner(DefaultRunner):
                     torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                 ],
                 dim=1,
-            ).to(self.init_device)
+            ).to(self.run_device)
         else:
             vae_input = torch.concat(
                 [

@@ -350,7 +351,7 @@ class WanRunner(DefaultRunner):
                     torch.zeros(3, self.config["target_video_length"] - 1, h, w),
                 ],
                 dim=1,
-            ).to(self.init_device)
+            ).to(self.run_device)
         vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()))

@@ -533,7 +534,7 @@ class Wan22DenseRunner(WanRunner):
         assert img.width == ow and img.height == oh
         # to tensor
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.init_device).unsqueeze(1)
+        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.run_device).unsqueeze(1)
         vae_encoder_out = self.get_vae_encoder_output(img)
         latent_w, latent_h = ow // self.config["vae_stride"][2], oh // self.config["vae_stride"][1]
         latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
lightx2v/models/schedulers/hunyuan_video/posemb_layers.py

@@ -271,8 +271,8 @@ def get_1d_rotary_pos_embed(
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
     # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
-    device = kwds.get("device", "cuda")
-    freqs = torch.outer(pos * interpolation_factor, freqs).to(device)  # [S, D/2]
+    run_device = kwds.get("run_device", "cuda")
+    freqs = torch.outer(pos * interpolation_factor, freqs).to(run_device)  # [S, D/2]
     if use_real:
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
lightx2v/models/schedulers/hunyuan_video/scheduler.py

@@ -11,7 +11,7 @@ from .posemb_layers import get_nd_rotary_pos_embed
 class HunyuanVideo15Scheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.device = torch.device(self.config.get("run_device", "cuda"))
+        self.run_device = torch.device(self.config.get("run_device", "cuda"))
         self.reverse = True
         self.num_train_timesteps = 1000
         self.sample_shift = self.config["sample_shift"]

@@ -25,13 +25,13 @@ class HunyuanVideo15Scheduler(BaseScheduler):
     def prepare(self, seed, latent_shape, image_encoder_output=None):
         self.prepare_latents(seed, latent_shape, dtype=torch.bfloat16)
-        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
         self.multitask_mask = self.get_task_mask(self.config["task"], latent_shape[-3])
         self.cond_latents_concat, self.mask_concat = self._prepare_cond_latents_and_mask(self.config["task"], image_encoder_output["cond_latents"], self.latents, self.multitask_mask, self.reorg_token)
         self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))

     def prepare_latents(self, seed, latent_shape, dtype=torch.bfloat16):
-        self.generator = torch.Generator(device=self.device).manual_seed(seed)
+        self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
         self.latents = torch.randn(
             1,
             latent_shape[0],

@@ -39,7 +39,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
             latent_shape[2],
             latent_shape[3],
             dtype=dtype,
-            device=self.device,
+            device=self.run_device,
             generator=self.generator,
         )

@@ -127,7 +127,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
         if rope_dim_list is None:
             rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
         assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
-        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=self.device)
+        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=self.run_device)
         cos_half = freqs_cos[:, ::2].contiguous()
         sin_half = freqs_sin[:, ::2].contiguous()
         cos_sin = torch.cat([cos_half, sin_half], dim=-1)

@@ -151,7 +151,7 @@ class HunyuanVideo15SRScheduler(HunyuanVideo15Scheduler):
         dtype = lq_latents.dtype
         device = lq_latents.device
         self.prepare_latents(seed, latent_shape, lq_latents, dtype=dtype)
-        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
         self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))
         tgt_shape = latent_shape[-2:]
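Seeding now happens on the run device: the torch.Generator and the latents it feeds must sit on the same device, otherwise torch.randn raises a device mismatch. A minimal sketch of that pairing, using CPU so it runs anywhere (the scheduler passes its configured run_device instead):

import torch

run_device = torch.device("cpu")  # stand-in for the scheduler's run_device
generator = torch.Generator(device=run_device).manual_seed(42)
latents = torch.randn(1, 16, 8, 8, generator=generator, device=run_device, dtype=torch.bfloat16)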