xuwx1 / LightX2V · Commits

Commit 4c0a9a0d (unverified)
Fix device bugs (#527)

Authored Nov 27, 2025 by Gu Shiqiao; committed via GitHub on Nov 27, 2025
Parent: fbb19ffc
Showing 9 changed files, with 54 additions and 31 deletions (+54 −31):

    lightx2v/models/schedulers/qwen_image/scheduler.py               +6  −6
    lightx2v/models/schedulers/wan/audio/scheduler.py                +3  −3
    lightx2v/models/schedulers/wan/changing_resolution/scheduler.py  +4  −4
    lightx2v/models/schedulers/wan/scheduler.py                      +4  −5
    lightx2v/models/schedulers/wan/self_forcing/scheduler.py         +7  −7
    lightx2v/models/schedulers/wan/step_distill/scheduler.py         +1  −1
    lightx2v/models/video_encoders/hf/wan/vae.py                     +3  −3
    scripts/seko_talk/run_seko_talk_06_offload_fp8_H100.sh           +2  −2
    scripts/seko_talk/run_seko_talk_25_mlu.sh (new)                  +24 −0
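The changes share one pattern: schedulers and the VAE previously pinned tensors to a hard-coded CUDA device or to a stale `self.device` attribute, which broke runs on non-CUDA backends. The fix resolves a single `run_device` from the config and routes every allocation through it. A minimal sketch of the pattern, assuming a plain `config` dict like the one these classes receive (the `"run_device"` value shown in the comment is only an illustration):

    import torch

    # Minimal sketch of the config-driven device resolution used throughout
    # this commit. With no "run_device" key the behavior falls back to CUDA.
    config = {}  # e.g. loaded from a JSON config; may contain {"run_device": "mlu"}
    run_device = torch.device(config.get("run_device", "cuda"))

    # Every allocation then targets run_device instead of a literal "cuda",
    # so the same code path works on CUDA, MLU, or NPU builds of PyTorch.
    latents = torch.randn(1, 16, 60, 104, device=run_device, dtype=torch.bfloat16)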
lightx2v/models/schedulers/qwen_image/scheduler.py

@@ -133,7 +133,7 @@ class QwenImageScheduler(BaseScheduler):
         self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(config["model_path"], "scheduler"))
         with open(os.path.join(config["model_path"], "scheduler", "scheduler_config.json"), "r") as f:
             self.scheduler_config = json.load(f)
-        self.device = torch.device(self.config.get("run_device", "cuda"))
+        self.run_device = torch.device(self.config.get("run_device", "cuda"))
         self.dtype = torch.bfloat16
         self.guidance_scale = 1.0

@@ -176,9 +176,9 @@ class QwenImageScheduler(BaseScheduler):
         shape = input_info.target_shape
         width, height = shape[-1], shape[-2]
-        latents = randn_tensor(shape, generator=self.generator, device=self.device, dtype=self.dtype)
+        latents = randn_tensor(shape, generator=self.generator, device=self.run_device, dtype=self.dtype)
         latents = self._pack_latents(latents, self.config["batchsize"], self.config["num_channels_latents"], height, width)
-        latent_image_ids = self._prepare_latent_image_ids(self.config["batchsize"], height // 2, width // 2, self.device, self.dtype)
+        latent_image_ids = self._prepare_latent_image_ids(self.config["batchsize"], height // 2, width // 2, self.run_device, self.dtype)
         self.latents = latents
         self.latent_image_ids = latent_image_ids

@@ -198,7 +198,7 @@ class QwenImageScheduler(BaseScheduler):
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
             num_inference_steps,
-            self.device,
+            self.run_device,
             sigmas=sigmas,
             mu=mu,
         )

@@ -213,7 +213,7 @@ class QwenImageScheduler(BaseScheduler):
     def prepare_guidance(self):
         # handle guidance
         if self.config["guidance_embeds"]:
-            guidance = torch.full([1], self.guidance_scale, device=self.device, dtype=torch.float32)
+            guidance = torch.full([1], self.guidance_scale, device=self.run_device, dtype=torch.float32)
             guidance = guidance.expand(self.latents.shape[0])
         else:
             guidance = None

@@ -223,7 +223,7 @@ class QwenImageScheduler(BaseScheduler):
         if self.config["task"] == "i2i":
             self.generator = torch.Generator().manual_seed(input_info.seed)
         elif self.config["task"] == "t2i":
-            self.generator = torch.Generator(device=self.device).manual_seed(input_info.seed)
+            self.generator = torch.Generator(device=self.run_device).manual_seed(input_info.seed)
         self.prepare_latents(input_info)
         self.prepare_guidance()
         self.set_timesteps()
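One subtlety worth noting: a seeded `torch.Generator` is tied to a device, and plain `torch.randn(..., generator=g, device=d)` raises unless the generator's device matches `d` (diffusers' `randn_tensor` is more lenient and moves the result instead). That is why the generator construction and the allocation change together in the Wan schedulers below. A small illustration, CPU-only so it runs anywhere:

    import torch

    gen_cpu = torch.Generator(device="cpu").manual_seed(42)

    # OK: generator device matches the allocation device.
    x = torch.randn(2, 3, device="cpu", generator=gen_cpu)

    # Mismatched devices raise, e.g. on a CUDA machine:
    #   torch.randn(2, 3, device="cuda", generator=gen_cpu)  # RuntimeError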
lightx2v/models/schedulers/wan/audio/scheduler.py

@@ -58,14 +58,14 @@ class EulerScheduler(WanScheduler):
         )

     def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
-        self.generator = torch.Generator(device=self.device).manual_seed(seed)
+        self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
         self.latents = torch.randn(
             latent_shape[0],
             latent_shape[1],
             latent_shape[2],
             latent_shape[3],
             dtype=dtype,
-            device=self.device,
+            device=self.run_device,
             generator=self.generator,
         )
         if self.config["model_cls"] == "wan2.2_audio":

@@ -77,7 +77,7 @@ class EulerScheduler(WanScheduler):
         self.prepare_latents(seed, latent_shape, dtype=torch.float32)
         timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
-        self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.device)
+        self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.run_device)
         self.timesteps_ori = self.timesteps.clone()
         self.sigmas = self.timesteps_ori / self.num_train_timesteps
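For context on the `np.linspace` line: the Euler schedule builds `infer_steps + 1` timesteps descending from `num_train_timesteps` to 0, so consecutive pairs bracket each denoising step; the fix only changes which device the resulting tensor lands on. A quick check of the grid with illustrative values:

    import numpy as np
    import torch

    num_train_timesteps, infer_steps = 1000, 4
    timesteps = np.linspace(num_train_timesteps, 0, infer_steps + 1, dtype=np.float32)
    print(timesteps)  # [1000. 750. 500. 250. 0.]

    t = torch.from_numpy(timesteps).to(dtype=torch.float32, device="cpu")
    print(t / num_train_timesteps)  # sigmas: tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])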
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py

@@ -20,7 +20,7 @@ class WanScheduler4ChangingResolution:
         assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"])

     def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
-        self.generator = torch.Generator(device=self.device).manual_seed(seed)
+        self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
         self.latents_list = []
         for i in range(len(self.config["resolution_rate"])):
             self.latents_list.append(

@@ -30,7 +30,7 @@ class WanScheduler4ChangingResolution:
                     int(latent_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
                     int(latent_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
                     dtype=dtype,
-                    device=self.device,
+                    device=self.run_device,
                     generator=self.generator,
                 )
             )

@@ -43,7 +43,7 @@ class WanScheduler4ChangingResolution:
                 latent_shape[2],
                 latent_shape[3],
                 dtype=dtype,
-                device=self.device,
+                device=self.run_device,
                 generator=self.generator,
             )
         )

@@ -83,7 +83,7 @@ class WanScheduler4ChangingResolution:
         # self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
         # 5. update timesteps using shift + self.changing_resolution_index + 1 (more aggressive denoising)
-        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + self.changing_resolution_index + 1)
+        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift + self.changing_resolution_index + 1)

     def add_noise(self, original_samples, noise, timesteps):
         sigma = self.sigmas[self.step_index]
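A side note on the `// 2 * 2` in the changing-resolution hunks: scaling a latent dimension by `resolution_rate` can yield an odd size, and the trailing `// 2 * 2` floors it to the nearest even number so the downscaled latents stay 2-divisible for patching. For example:

    # int(x * rate) // 2 * 2 floors the scaled size to an even number.
    latent_h, rate = 45, 0.75
    scaled = int(latent_h * rate)  # 33
    even = scaled // 2 * 2         # 32
    print(scaled, even)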
lightx2v/models/schedulers/wan/scheduler.py

@@ -10,11 +10,10 @@ from lightx2v.utils.utils import masks_like
 class WanScheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.device = torch.device(self.config.get("run_device", "cuda"))
+        self.run_device = torch.device(self.config.get("run_device", "cuda"))
         self.infer_steps = self.config["infer_steps"]
         self.target_video_length = self.config["target_video_length"]
         self.sample_shift = self.config["sample_shift"]
-        self.run_device = self.config.get("run_device", "cuda")
         self.patch_size = (1, 2, 2)
         self.shift = 1
         self.num_train_timesteps = 1000

@@ -65,7 +64,7 @@ class WanScheduler(BaseScheduler):
         self.sigma_min = self.sigmas[-1].item()
         self.sigma_max = self.sigmas[0].item()
-        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
         self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0], latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))

@@ -93,14 +92,14 @@ class WanScheduler(BaseScheduler):
         return cos_sin

     def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
-        self.generator = torch.Generator(device=self.device).manual_seed(seed)
+        self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
         self.latents = torch.randn(
             latent_shape[0],
             latent_shape[1],
             latent_shape[2],
             latent_shape[3],
             dtype=dtype,
-            device=self.device,
+            device=self.run_device,
             generator=self.generator,
         )
         if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
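The `WanScheduler.__init__` hunk also deletes a second assignment that had stored `run_device` as a raw string. Keeping both a `torch.device` and a plain string under device-named attributes is exactly the inconsistency this commit removes: the two print the same through `str()`, but only one is a `torch.device`, so APIs expecting a device object behave differently. A short illustration:

    import torch

    as_device = torch.device("cuda")
    as_string = "cuda"
    print(str(as_device) == as_string)          # True: substring checks still work
    print(isinstance(as_string, torch.device))  # False: not interchangeable everywhere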
lightx2v/models/schedulers/wan/self_forcing/scheduler.py

@@ -7,7 +7,7 @@ from lightx2v.utils.envs import *
 class WanSFScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.device = torch.device("cuda")
+        self.run_device = torch.device(config.get("run_device", "cuda"))
         self.dtype = torch.bfloat16
         self.num_frame_per_block = self.config["sf_config"]["num_frame_per_block"]
         self.num_output_frames = self.config["sf_config"]["num_output_frames"]

@@ -27,20 +27,20 @@ class WanSFScheduler(WanScheduler):
         self.context_noise = 0

     def prepare(self, seed, latent_shape, image_encoder_output=None):
-        self.latents = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
+        self.latents = torch.randn(latent_shape, device=self.run_device, dtype=self.dtype)
         timesteps = []
         for frame_block_idx, current_num_frames in enumerate(self.all_num_frames):
             frame_steps = []
             for step_index, current_timestep in enumerate(self.denoising_step_list):
-                timestep = torch.ones([self.num_frame_per_block], device=self.device, dtype=torch.int64) * current_timestep
+                timestep = torch.ones([self.num_frame_per_block], device=self.run_device, dtype=torch.int64) * current_timestep
                 frame_steps.append(timestep)
             timesteps.append(frame_steps)
         self.timesteps = timesteps
-        self.noise_pred = torch.zeros(latent_shape, device=self.device, dtype=self.dtype)
+        self.noise_pred = torch.zeros(latent_shape, device=self.run_device, dtype=self.dtype)
         sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * self.denoising_strength
         if self.extra_one_step:

@@ -52,10 +52,10 @@ class WanSFScheduler(WanScheduler):
         self.sigmas_sf = self.sf_shift * self.sigmas_sf / (1 + (self.sf_shift - 1) * self.sigmas_sf)
         if self.reverse_sigmas:
             self.sigmas_sf = 1 - self.sigmas_sf
-        self.sigmas_sf = self.sigmas_sf.to(self.device)
+        self.sigmas_sf = self.sigmas_sf.to(self.run_device)
         self.timesteps_sf = self.sigmas_sf * self.num_train_timesteps
-        self.timesteps_sf = self.timesteps_sf.to(self.device)
+        self.timesteps_sf = self.timesteps_sf.to(self.run_device)
         self.stream_output = None

@@ -93,7 +93,7 @@ class WanSFScheduler(WanScheduler):
         # add noise
         if self.step_index < self.infer_steps - 1:
-            timestep_next = self.timesteps[self.seg_index][self.step_index + 1] * torch.ones(self.num_frame_per_block, device=self.device, dtype=torch.long)
+            timestep_next = self.timesteps[self.seg_index][self.step_index + 1] * torch.ones(self.num_frame_per_block, device=self.run_device, dtype=torch.long)
             timestep_id_next = torch.argmin((self.timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1)
             sigma_next = self.sigmas_sf[timestep_id_next].reshape(-1, 1, 1, 1)
             noise_next = torch.randn_like(x0_pred)
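The unchanged `torch.argmin(...)` line in the last hunk explains why `timestep_next` had to move: it maps each next-step timestep to the index of the closest entry in `timesteps_sf` by broadcasting a pairwise difference, and that broadcast only works when both tensors live on the same device. A small standalone version of the lookup:

    import torch

    # Nearest-timestep lookup, as in the self-forcing scheduler.
    timesteps_sf = torch.tensor([0.0, 250.0, 500.0, 750.0, 1000.0])
    timestep_next = torch.tensor([240.0, 760.0])

    # (1, N) - (M, 1) broadcasts to an (M, N) distance table; argmin over
    # dim=1 picks the closest precomputed timestep for each query.
    ids = torch.argmin((timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1)
    print(ids)  # tensor([1, 3])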
lightx2v/models/schedulers/wan/step_distill/scheduler.py

@@ -19,7 +19,7 @@ class WanStepDistillScheduler(WanScheduler):
     def prepare(self, seed, latent_shape, image_encoder_output=None):
         self.prepare_latents(seed, latent_shape, dtype=torch.float32)
-        self.set_denoising_timesteps(device=self.device)
+        self.set_denoising_timesteps(device=self.run_device)

     def set_denoising_timesteps(self, device: Union[str, torch.device] = None):
         sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min)
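Since `set_denoising_timesteps` accepts `Union[str, torch.device]`, either form of `run_device` works at the call site; `torch.device()` normalizes both. A one-line check of the equivalence:

    import torch

    # torch.device() is idempotent: strings and device objects normalize the same.
    assert torch.device("cuda:0") == torch.device(torch.device("cuda:0"))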
lightx2v/models/video_encoders/hf/wan/vae.py

@@ -1330,9 +1330,9 @@ class WanVAE:
     def device_synchronize(self,):
-        if "cuda" in str(self.device):
+        if "cuda" in str(self.run_device):
             torch.cuda.synchronize()
-        elif "mlu" in str(self.device):
+        elif "mlu" in str(self.run_device):
             torch.mlu.synchronize()
-        elif "npu" in str(self.device):
+        elif "npu" in str(self.run_device):
             torch.npu.synchronize()
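The `device_synchronize` fix is the same rename, but the method itself is a useful template for backend-agnostic synchronization. A hedged standalone version (the `torch.mlu` / `torch.npu` modules only exist when the Cambricon / Ascend PyTorch extensions are installed, hence the `hasattr` guards added here):

    import torch

    def device_synchronize(run_device) -> None:
        # Dispatch on the device string; each branch assumes the matching
        # vendor extension (torch_mlu / torch_npu) registered its module.
        name = str(run_device)
        if "cuda" in name:
            torch.cuda.synchronize()
        elif "mlu" in name and hasattr(torch, "mlu"):
            torch.mlu.synchronize()
        elif "npu" in name and hasattr(torch, "npu"):
            torch.npu.synchronize()

    device_synchronize(torch.device("cpu"))  # no-op on CPU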
scripts/seko_talk/run_seko_talk_06_offload_fp8_H100.sh

 #!/bin/bash
-lightx2v_path=/mtc/gushiqiao/llmc_workspace/lightx2v_latest2/LightX2V
-model_path=/data/nvme0/gushiqiao/models/Lightx2v_models/seko-new/SekoTalk-Distill-fp8/
+lightx2v_path=/path/to/LightX2V
+model_path=/path/to/SekoTalk-Distill-fp8/

 export CUDA_VISIBLE_DEVICES=0
scripts/seko_talk/run_seko_talk_25_mlu.sh (new file, mode 0 → 100755)

+#!/bin/bash
+lightx2v_path=/path/to/Lightx2v
+model_path=/path/to/SekoTalk-Distill
+
+export MLU_VISIBLE_DEVICES=0
+
+# set environment variables
+source ${lightx2v_path}/scripts/base/base.sh
+
+export SENSITIVE_LAYER_DTYPE=None
+
+python -m lightx2v.infer \
+--model_cls seko_talk \
+--task s2v \
+--model_path $model_path \
+--config_json ${lightx2v_path}/configs/seko_talk/mlu/seko_talk_bf16.json \
+--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \
+--negative_prompt "Garish colors, overexposed, static, blurry details, subtitles, style, artwork, painting, still frame, overall gray cast, worst quality, low quality, JPEG compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, crowded background, walking backwards" \
+--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \
+--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \
+--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4
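A usage note on the new MLU script: `MLU_VISIBLE_DEVICES` plays the role `CUDA_VISIBLE_DEVICES` plays for NVIDIA, and the referenced `seko_talk_bf16.json` presumably carries the `run_device` setting the schedulers above read (an assumption; check the config in the repo). A hedged pre-flight check for which backend the current interpreter can actually target:

    import torch

    # Hypothetical helper, not part of LightX2V: report which backend this
    # PyTorch build can use. torch.mlu / torch.npu exist only when the
    # Cambricon / Ascend extensions are installed.
    def available_backend() -> str:
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch, "mlu") and torch.mlu.is_available():
            return "mlu"
        if hasattr(torch, "npu") and torch.npu.is_available():
            return "npu"
        return "cpu"

    print(available_backend())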