xuwx1 / LightX2V / Commits

Commit 19ac1216 (Unverified)
[feat]: support server of self-forcing & matrix-game2 (#533)
Authored Nov 28, 2025 by Watebear; committed by GitHub on Nov 28, 2025.
Parent: bcb74974

Showing 15 changed files with 635 additions and 25 deletions (+635, -25)
configs/matrix_game2/matrix_game2_gta_drive.json                           +2   -2
configs/matrix_game2/matrix_game2_gta_drive_streaming.json                 +2   -2
configs/matrix_game2/matrix_game2_templerun.json                           +2   -2
configs/matrix_game2/matrix_game2_templerun_streaming.json                 +2   -2
configs/matrix_game2/matrix_game2_universal.json                           +2   -2
configs/matrix_game2/matrix_game2_universal_streaming.json                 +2   -2
configs/model_pipeline.json                                                +33  -0
lightx2v/deploy/common/video_recorder.py                                   +422 -0
lightx2v/deploy/worker/__main__.py                                         +3   -1
lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py               +3   -1
lightx2v/models/networks/wan/weights/matrix_game2/transformer_weights.py   +3   -1
lightx2v/models/runners/wan/wan_matrix_game2_runner.py                     +50  -1
lightx2v/models/runners/wan/wan_sf_runner.py                               +88  -8
lightx2v/models/schedulers/wan/self_forcing/scheduler.py                   +1   -1
test_cases/run_matrix_game2_gta_drive.sh                                   +20  -0
configs/matrix_game2/matrix_game2_gta_drive.json

@@ -3,8 +3,8 @@
     "target_video_length": 150,
     "num_output_frames": 150,
     "text_len": 512,
-    "target_height": 480,
-    "target_width": 832,
+    "target_height": 352,
+    "target_width": 640,
     "self_attn_1_type": "flash_attn2",
     "cross_attn_1_type": "flash_attn2",
     "cross_attn_2_type": "flash_attn2",
configs/matrix_game2/matrix_game2_gta_drive_streaming.json

@@ -3,8 +3,8 @@
     "target_video_length": 360,
     "num_output_frames": 360,
     "text_len": 512,
-    "target_height": 480,
-    "target_width": 832,
+    "target_height": 352,
+    "target_width": 640,
     "self_attn_1_type": "flash_attn2",
     "cross_attn_1_type": "flash_attn2",
     "cross_attn_2_type": "flash_attn2",
configs/matrix_game2/matrix_game2_templerun.json

@@ -3,8 +3,8 @@
     "target_video_length": 150,
     "num_output_frames": 150,
     "text_len": 512,
-    "target_height": 480,
-    "target_width": 832,
+    "target_height": 352,
+    "target_width": 640,
     "self_attn_1_type": "flash_attn2",
     "cross_attn_1_type": "flash_attn2",
     "cross_attn_2_type": "flash_attn2",
configs/matrix_game2/matrix_game2_templerun_streaming.json

@@ -3,8 +3,8 @@
     "target_video_length": 360,
     "num_output_frames": 360,
     "text_len": 512,
-    "target_height": 480,
-    "target_width": 832,
+    "target_height": 352,
+    "target_width": 640,
     "self_attn_1_type": "flash_attn2",
     "cross_attn_1_type": "flash_attn2",
     "cross_attn_2_type": "flash_attn2",
configs/matrix_game2/matrix_game2_universal.json

@@ -3,8 +3,8 @@
     "target_video_length": 150,
     "num_output_frames": 150,
     "text_len": 512,
-    "target_height": 480,
-    "target_width": 832,
+    "target_height": 352,
+    "target_width": 640,
     "self_attn_1_type": "flash_attn2",
     "cross_attn_1_type": "flash_attn2",
     "cross_attn_2_type": "flash_attn2",
configs/matrix_game2/matrix_game2_universal_streaming.json

@@ -3,8 +3,8 @@
     "target_video_length": 360,
     "num_output_frames": 360,
     "text_len": 512,
-    "target_height": 480,
-    "target_width": 832,
+    "target_height": 352,
+    "target_width": 640,
     "self_attn_1_type": "flash_attn2",
     "cross_attn_1_type": "flash_attn2",
     "cross_attn_2_type": "flash_attn2",
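All six matrix_game2 configs receive the same two-line change: the render target drops from 832x480 to 640x352 (width x height), in both the fixed-length (150-frame) and streaming (360-frame) variants. A minimal sanity check, assuming a checkout of this commit and a repo-relative working directory:

import json

# Hedged sketch: verify the new default resolution in one of the updated configs.
with open("configs/matrix_game2/matrix_game2_gta_drive.json") as f:
    cfg = json.load(f)
assert (cfg["target_width"], cfg["target_height"]) == (640, 352)
assert cfg["target_video_length"] == 150  # the *_streaming variants use 360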
configs/model_pipeline.json

@@ -23,6 +23,14 @@
                     "outputs": ["output_video"]
                 }
             }
+        },
+        "self-forcing-dmd": {
+            "single_stage": {
+                "pipeline": {
+                    "inputs": [],
+                    "outputs": ["output_video"]
+                }
+            }
         }
     },
     "i2v": {
@@ -59,6 +67,30 @@
                     "outputs": ["output_video"]
                 }
             }
+        },
+        "matrix-game2-gta-drive": {
+            "single_stage": {
+                "pipeline": {
+                    "inputs": ["input_image"],
+                    "outputs": ["output_video"]
+                }
+            }
+        },
+        "matrix-game2-universal": {
+            "single_stage": {
+                "pipeline": {
+                    "inputs": ["input_image"],
+                    "outputs": ["output_video"]
+                }
+            }
+        },
+        "matrix-game2-templerun": {
+            "single_stage": {
+                "pipeline": {
+                    "inputs": ["input_image"],
+                    "outputs": ["output_video"]
+                }
+            }
         }
     },
    "s2v": {
@@ -112,6 +144,7 @@
    "subtask_running_timeouts": {
        "t2v-wan2.1-1.3B-multi_stage-dit": 300,
        "t2v-wan2.1-1.3B-single_stage-pipeline": 300,
+       "t2v-self-forcing-dmd-single_stage-pipeline": 300,
        "i2v-wan2.1-14B-480P-multi_stage-dit": 600,
        "i2v-wan2.1-14B-480P-single_stage-pipeline": 600,
        "i2v-SekoTalk-Distill-single_stage-pipeline": 3600,
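The pipeline registry and the timeout table share a `task-model-stage-step` key scheme, so the new `self-forcing-dmd` entry under `t2v` pairs with the new `t2v-self-forcing-dmd-single_stage-pipeline: 300` timeout. A small illustration of the naming convention (the helper below is hypothetical, not code from this commit):

# Hypothetical helper mirroring the "<task>-<model>-<stage>-<step>" convention.
def subtask_key(task: str, model: str, stage: str, step: str) -> str:
    return f"{task}-{model}-{stage}-{step}"

assert subtask_key("t2v", "self-forcing-dmd", "single_stage", "pipeline") == "t2v-self-forcing-dmd-single_stage-pipeline"  # timeout: 300 s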
lightx2v/deploy/common/video_recorder.py (new file, 0 → 100644)

import os
import queue
import socket
import subprocess
import threading
import time
import traceback

import numpy as np
import torch
from loguru import logger


def pseudo_random(a, b):
    x = str(time.time()).split(".")[1]
    y = int(float("0." + x) * 1000000)
    return a + (y % (b - a + 1))


class VideoRecorder:
    def __init__(
        self,
        livestream_url: str,
        fps: float = 16.0,
    ):
        self.livestream_url = livestream_url
        self.fps = fps
        self.video_port = pseudo_random(32000, 40000)
        self.ffmpeg_log_level = os.getenv("FFMPEG_LOG_LEVEL", "error")
        logger.info(f"VideoRecorder video port: {self.video_port}, ffmpeg_log_level: {self.ffmpeg_log_level}")
        self.width = None
        self.height = None
        self.stoppable_t = None
        self.realtime = True
        # ffmpeg process for video data and push to livestream
        self.ffmpeg_process = None
        # TCP connection objects
        self.video_socket = None
        self.video_conn = None
        self.video_thread = None
        # queue for send data to ffmpeg process
        self.video_queue = queue.Queue()

    def init_sockets(self):
        # TCP socket for send and recv video data
        self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.video_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.video_socket.bind(("127.0.0.1", self.video_port))
        self.video_socket.listen(1)

    def video_worker(self):
        try:
            logger.info("Waiting for ffmpeg to connect to video socket...")
            self.video_conn, _ = self.video_socket.accept()
            logger.info(f"Video connection established from {self.video_conn.getpeername()}")
            fail_time, max_fail_time = 0, 10
            packet_secs = 1.0 / self.fps
            while True:
                try:
                    if self.video_queue is None:
                        break
                    data = self.video_queue.get()
                    if data is None:
                        logger.info("Video thread received stop signal")
                        break
                    # Convert to numpy and scale to [0, 255], convert RGB to BGR for OpenCV/FFmpeg
                    for i in range(data.shape[0]):
                        t0 = time.time()
                        frame = (data[i] * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
                        try:
                            self.video_conn.send(frame.tobytes())
                        except (BrokenPipeError, OSError, ConnectionResetError) as e:
                            logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
                            return
                        if self.realtime:
                            time.sleep(max(0, packet_secs - (time.time() - t0)))
                    fail_time = 0
                except (BrokenPipeError, OSError, ConnectionResetError):
                    logger.info("Video connection closed during queue processing")
                    break
                except Exception:
                    logger.error(f"Send video data error: {traceback.format_exc()}")
                    fail_time += 1
                    if fail_time > max_fail_time:
                        logger.error(f"Video push worker thread failed {fail_time} times, stopping...")
                        break
        except Exception:
            logger.error(f"Video push worker thread error: {traceback.format_exc()}")
        finally:
            logger.info("Video push worker thread stopped")

    def start_ffmpeg_process_local(self):
        """Start ffmpeg process that connects to our TCP sockets"""
        ffmpeg_cmd = [
            "ffmpeg",
            "-fflags", "nobuffer",
            "-analyzeduration", "0",
            "-probesize", "32",
            "-flush_packets", "1",
            "-f", "rawvideo",
            "-pix_fmt", "rgb24",
            "-color_range", "pc",
            "-colorspace", "rgb",
            "-color_primaries", "bt709",
            "-color_trc", "iec61966-2-1",
            "-r", str(self.fps),
            "-s", f"{self.width}x{self.height}",
            "-i", f"tcp://127.0.0.1:{self.video_port}",
            "-b:v", "4M",
            "-c:v", "libx264",
            "-preset", "ultrafast",
            "-tune", "zerolatency",
            "-g", f"{self.fps}",
            "-pix_fmt", "yuv420p",
            "-f", "mp4",
            self.livestream_url,
            "-y",
            "-loglevel", self.ffmpeg_log_level,
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg: {e}")

    def start_ffmpeg_process_rtmp(self):
        """Start ffmpeg process that connects to our TCP sockets"""
        ffmpeg_cmd = [
            "ffmpeg",
            "-f", "rawvideo",
            "-re",
            "-pix_fmt", "rgb24",
            "-r", str(self.fps),
            "-s", f"{self.width}x{self.height}",
            "-i", f"tcp://127.0.0.1:{self.video_port}",
            "-b:v", "2M",
            "-c:v", "libx264",
            "-preset", "ultrafast",
            "-tune", "zerolatency",
            "-g", f"{self.fps}",
            "-pix_fmt", "yuv420p",
            "-f", "flv",
            self.livestream_url,
            "-y",
            "-loglevel", self.ffmpeg_log_level,
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg: {e}")

    def start_ffmpeg_process_whip(self):
        """Start ffmpeg process that connects to our TCP sockets"""
        ffmpeg_cmd = [
            "ffmpeg",
            "-re",
            "-fflags", "nobuffer",
            "-analyzeduration", "0",
            "-probesize", "32",
            "-flush_packets", "1",
            "-f", "rawvideo",
            "-re",
            "-pix_fmt", "rgb24",
            "-r", str(self.fps),
            "-s", f"{self.width}x{self.height}",
            "-i", f"tcp://127.0.0.1:{self.video_port}",
            "-b:v", "2M",
            "-c:v", "libx264",
            "-preset", "ultrafast",
            "-tune", "zerolatency",
            "-g", f"{self.fps}",
            "-pix_fmt", "yuv420p",
            "-threads", "1",
            "-bf", "0",
            "-f", "whip",
            self.livestream_url,
            "-y",
            "-loglevel", self.ffmpeg_log_level,
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg: {e}")

    def start(self, width: int, height: int):
        self.set_video_size(width, height)
        duration = 1.0
        self.pub_video(torch.zeros((int(self.fps * duration), height, width, 3), dtype=torch.float16))
        time.sleep(duration)

    def set_video_size(self, width: int, height: int):
        if self.width is not None and self.height is not None:
            assert self.width == width and self.height == height, "Video size already set"
            return
        self.width = width
        self.height = height
        self.init_sockets()
        if self.livestream_url.startswith("rtmp://"):
            self.start_ffmpeg_process_rtmp()
        elif self.livestream_url.startswith("http"):
            self.start_ffmpeg_process_whip()
        else:
            self.start_ffmpeg_process_local()
            self.realtime = False
        self.video_thread = threading.Thread(target=self.video_worker)
        self.video_thread.start()

    # Publish ComfyUI Image tensor to livestream
    def pub_video(self, images: torch.Tensor):
        N, height, width, C = images.shape
        assert C == 3, "Input must be [N, H, W, C] with C=3"
        logger.info(f"Publishing video [{N}x{width}x{height}]")
        self.set_video_size(width, height)
        self.video_queue.put(images)
        logger.info(f"Published {N} frames")
        self.stoppable_t = time.time() + N / self.fps + 3

    def stop(self, wait=True):
        if wait and self.stoppable_t:
            t = self.stoppable_t - time.time()
            if t > 0:
                logger.warning(f"Waiting for {t} seconds to stop ...")
                time.sleep(t)
            self.stoppable_t = None
        # Send stop signals to queues
        if self.video_queue:
            self.video_queue.put(None)
        # Wait for threads to finish processing queued data (increased timeout)
        queue_timeout = 30  # Increased from 5s to 30s to allow sufficient time for large video frames
        if self.video_thread and self.video_thread.is_alive():
            self.video_thread.join(timeout=queue_timeout)
            if self.video_thread.is_alive():
                logger.error(f"Video push thread did not stop after {queue_timeout}s")
        # Shutdown connections to signal EOF to FFmpeg
        # shutdown(SHUT_WR) will wait for send buffer to flush, no explicit sleep needed
        if self.video_conn:
            try:
                self.video_conn.getpeername()
                self.video_conn.shutdown(socket.SHUT_WR)
                logger.info("Video connection shutdown initiated")
            except OSError:
                # Connection already closed, skip shutdown
                pass
        if self.ffmpeg_process:
            is_local_file = not self.livestream_url.startswith(("rtmp://", "http"))
            # Local MP4 files need time to write moov atom and finalize the container
            timeout_seconds = 30 if is_local_file else 10
            logger.info(f"Waiting for FFmpeg to finalize file (timeout={timeout_seconds}s, local_file={is_local_file})")
            logger.info(f"FFmpeg output: {self.livestream_url}")
            try:
                returncode = self.ffmpeg_process.wait(timeout=timeout_seconds)
                if returncode == 0:
                    logger.info(f"FFmpeg process exited successfully (exit code: {returncode})")
                else:
                    logger.warning(f"FFmpeg process exited with non-zero code: {returncode}")
            except subprocess.TimeoutExpired:
                logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...")
                try:
                    self.ffmpeg_process.terminate()  # SIGTERM
                    returncode = self.ffmpeg_process.wait(timeout=5)
                    logger.warning(f"FFmpeg process terminated with SIGTERM (exit code: {returncode})")
                except subprocess.TimeoutExpired:
                    logger.error("FFmpeg process still running after SIGTERM, killing with SIGKILL...")
                    self.ffmpeg_process.kill()
                    self.ffmpeg_process.wait()  # Wait for kill to complete
                    logger.error("FFmpeg process killed with SIGKILL")
            finally:
                self.ffmpeg_process = None
        if self.video_conn:
            try:
                self.video_conn.close()
            except Exception as e:
                logger.debug(f"Error closing video connection: {e}")
            finally:
                self.video_conn = None
        if self.video_socket:
            try:
                self.video_socket.close()
            except Exception as e:
                logger.debug(f"Error closing video socket: {e}")
            finally:
                self.video_socket = None
        if self.video_queue:
            while self.video_queue.qsize() > 0:
                try:
                    self.video_queue.get_nowait()
                except:  # noqa
                    break
            self.video_queue = None
        logger.info("VideoRecorder stopped and resources cleaned up")

    def __del__(self):
        self.stop(wait=False)


def create_simple_video(frames=10, height=480, width=640):
    video_data = []
    for i in range(frames):
        frame = np.zeros((height, width, 3), dtype=np.float32)
        stripe_height = height // 8
        colors = [
            [1.0, 0.0, 0.0],  # red
            [0.0, 1.0, 0.0],  # green
            [0.0, 0.0, 1.0],  # blue
            [1.0, 1.0, 0.0],  # yellow
            [1.0, 0.0, 1.0],  # magenta
            [0.0, 1.0, 1.0],  # cyan
            [1.0, 1.0, 1.0],  # white
            [0.5, 0.5, 0.5],  # gray
        ]
        for j, color in enumerate(colors):
            start_y = j * stripe_height
            end_y = min((j + 1) * stripe_height, height)
            frame[start_y:end_y, :] = color
        offset = int((i / frames) * width)
        frame = np.roll(frame, offset, axis=1)
        frame = torch.tensor(frame, dtype=torch.float32)
        video_data.append(frame)
    return torch.stack(video_data, dim=0)


if __name__ == "__main__":
    fps = 16
    width = 640
    height = 480
    recorder = VideoRecorder(
        # livestream_url="rtmp://localhost/live/test",
        # livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
        livestream_url="/path/to/output_video.mp4",
        fps=fps,
    )
    secs = 10  # 10-second video
    interval = 1
    for i in range(0, secs, interval):
        logger.info(f"{i}/{secs}s")
        num_frames = int(interval * fps)
        images = create_simple_video(num_frames, height, width)
        logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}")
        recorder.pub_video(images)
        time.sleep(interval)
    recorder.stop()
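The recorder's flow: pub_video lazily calls set_video_size, which binds a loopback TCP socket on a pseudo-random port, spawns ffmpeg reading raw rgb24 frames from that socket, and starts a worker thread that drains video_queue and writes frame bytes to the connection; the output muxer is chosen from the URL prefix (rtmp:// → FLV, http → WHIP, otherwise a local MP4 with realtime pacing disabled). A minimal usage sketch, assuming the module path from this commit and a writable output path:

import torch
from lightx2v.deploy.common.video_recorder import VideoRecorder

# Hedged sketch: push 2 seconds of random [N, H, W, C] float frames in [0, 1]
# to a local MP4 (path hypothetical), then let ffmpeg finalize the container.
rec = VideoRecorder(livestream_url="/tmp/demo.mp4", fps=16)
rec.pub_video(torch.rand(32, 352, 640, 3))  # 32 frames at 16 fps = 2 s
rec.stop()  # waits for queued frames to drain before shutting ffmpeg down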
lightx2v/deploy/worker/__main__.py

@@ -296,7 +296,7 @@ async def shutdown(loop):
 # align args like infer.py
 def align_args(args):
     args.seed = 42
-    args.sf_model_path = ""
+    args.sf_model_path = args.sf_model_path if args.sf_model_path else ""
     args.use_prompt_enhancer = False
     args.prompt = ""
     args.negative_prompt = ""
@@ -308,6 +308,7 @@ def align_args(args):
     args.src_mask = None
     args.save_result_path = ""
     args.return_result_tensor = False
+    args.is_live = True
     # =========================
@@ -335,6 +336,7 @@ if __name__ == "__main__":
     parser.add_argument("--metric_port", type=int, default=8001)
     parser.add_argument("--model_path", type=str, required=True)
+    parser.add_argument("--sf_model_path", type=str, default="")
     parser.add_argument("--config_json", type=str, required=True)
     parser.add_argument("--server", type=str, default="http://127.0.0.1:8080")
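align_args previously hard-reset sf_model_path to an empty string; with the new --sf_model_path flag the worker now preserves a user-supplied value, and is_live is forced on so the self-forcing runners take the streaming entry point. A tiny sketch of the preserved-if-set idiom (the namespace and path below are illustrative only):

from types import SimpleNamespace

# Hedged sketch: the changed align_args line keeps a provided --sf_model_path.
args = SimpleNamespace(sf_model_path="/models/self-forcing-dmd")  # hypothetical path
args.sf_model_path = args.sf_model_path if args.sf_model_path else ""
assert args.sf_model_path == "/models/self-forcing-dmd"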
lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py

-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from typing import Any, Dict

 import torch

@@ -35,6 +36,7 @@ class WanSFPreInferModuleOutput:
     seq_lens: torch.Tensor
     freqs: torch.Tensor
     context: torch.Tensor
+    conditional_dict: Dict[str, Any] = field(default_factory=dict)


 class WanSFPreInfer(WanPreInfer):
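The new conditional_dict field is mutable, so it must use field(default_factory=dict) rather than a plain `= {}` default, which dataclasses rejects for mutable values; that is what the widened imports are for. A self-contained illustration (the class name is a stand-in):

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class Output:  # stand-in for WanSFPreInferModuleOutput's new field
    conditional_dict: Dict[str, Any] = field(default_factory=dict)

a, b = Output(), Output()
a.conditional_dict["k"] = 1
assert b.conditional_dict == {}  # each instance gets its own dict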
lightx2v/models/networks/wan/weights/matrix_game2/transformer_weights.py

@@ -33,7 +33,7 @@ class WanActionTransformerWeights(WeightModule):
             if i in action_blocks:
                 block_list.append(WanTransformerActionBlock(i, self.task, self.mm_type, self.config, "blocks"))
             else:
-                block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config, "blocks"))
+                block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, "blocks"))
         self.blocks = WeightModuleList(block_list)
         self.add_module("blocks", self.blocks)
@@ -82,6 +82,7 @@ class WanTransformerActionBlock(WeightModule):
                 task,
                 mm_type,
                 config,
+                False,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -109,6 +110,7 @@ class WanTransformerActionBlock(WeightModule):
                 task,
                 mm_type,
                 config,
+                False,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
lightx2v/models/runners/wan/wan_matrix_game2_runner.py

 import os

 import torch
-from diffusers.utils import load_image
+from diffusers.utils.loading_utils import load_image
 from torchvision.transforms import v2

 from lightx2v.models.input_encoders.hf.wan.matrix_game2.clip import CLIPModel

@@ -272,6 +272,55 @@ class WanSFMtxg2Runner(WanSFRunner):
                     if stop == "n":
                         break
                 stop = "n"

         gen_video_final = self.process_images_after_vae_decoder()
         self.end_run()
         return gen_video_final
+
+    @ProfilingContext4DebugL2("Run DiT")
+    def run_main_live(self, total_steps=None):
+        try:
+            self.init_video_recorder()
+            logger.info(f"init video_recorder: {self.video_recorder}")
+            rank, world_size = self.get_rank_and_world_size()
+            if rank == world_size - 1:
+                assert self.video_recorder is not None, "video_recorder is required for stream audio input for rank 2"
+                self.video_recorder.start(self.width, self.height)
+            if world_size > 1:
+                dist.barrier()
+            self.init_run()
+            if self.config.get("compile", False):
+                self.model.select_graph_for_compile(self.input_info)
+            stop = ""
+            while stop != "n":
+                for segment_idx in range(self.video_segment_num):
+                    logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
+                    with ProfilingContext4DebugL1(
+                        f"segment end2end {segment_idx + 1}/{self.video_segment_num}",
+                        recorder_mode=GET_RECORDER_MODE(),
+                        metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration,
+                        metrics_labels=["DefaultRunner"],
+                    ):
+                        self.check_stop()
+                        # 1. default do nothing
+                        self.init_run_segment(segment_idx)
+                        # 2. main inference loop
+                        latents = self.run_segment(segment_idx=segment_idx)
+                        # 3. vae decoder
+                        self.gen_video = self.run_vae_decoder(latents)
+                        # 4. default do nothing
+                        self.end_run_segment(segment_idx)
+                    # 5. stop or not
+                    if self.config["streaming"]:
+                        stop = input("Press `n` to stop generation: ").strip().lower()
+                        if stop == "n":
+                            break
+                    stop = "n"
+        finally:
+            if hasattr(self.model, "inputs"):
+                self.end_run()
+            if self.video_recorder:
+                self.video_recorder.stop()
+                self.video_recorder = None
lightx2v/models/runners/wan/wan_sf_runner.py

@@ -3,15 +3,18 @@ import gc
 import torch
 from loguru import logger

+from lightx2v.deploy.common.video_recorder import VideoRecorder
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.networks.wan.sf_model import WanSFModel
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.self_forcing.scheduler import WanSFScheduler
 from lightx2v.models.video_encoders.hf.wan.vae_sf import WanSFVAE
+from lightx2v.server.metrics import monitor_cli
 from lightx2v.utils.envs import *
 from lightx2v.utils.memory_profiler import peak_memory_decorator
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
+from lightx2v.utils.utils import vae_to_comfyui_image_inplace


 @RUNNER_REGISTER("wan2.1_sf")
@@ -19,6 +22,11 @@ class WanSFRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
         self.vae_cls = WanSFVAE
+        self.is_live = config.get("is_live", False)
+        if self.is_live:
+            self.width = self.config["target_width"]
+            self.height = self.config["target_height"]
+            self.run_main = self.run_main_live

     def load_transformer(self):
         model = WanSFModel(
@@ -61,14 +69,6 @@ class WanSFRunner(WanRunner):
     def init_run(self):
         super().init_run()

-    @ProfilingContext4DebugL1("End run segment")
-    def end_run_segment(self, segment_idx=None):
-        with ProfilingContext4DebugL1("step_pre_in_rerun"):
-            self.model.scheduler.step_pre(seg_index=segment_idx, step_index=self.model.scheduler.infer_steps - 1, is_rerun=True)
-        with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
-            self.model.infer(self.inputs)
-        self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video
-
     @peak_memory_decorator
     def run_segment(self, segment_idx=0):
         infer_steps = self.model.scheduler.infer_steps
@@ -93,3 +93,83 @@ class WanSFRunner(WanRunner):
             self.progress_callback((current_step / total_all_steps) * 100, 100)

         return self.model.scheduler.stream_output
+
+    def get_rank_and_world_size(self):
+        rank = 0
+        world_size = 1
+        if dist.is_initialized():
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        return rank, world_size
+
+    def init_video_recorder(self):
+        output_video_path = self.input_info.save_result_path
+        self.video_recorder = None
+        if isinstance(output_video_path, dict):
+            output_video_path = output_video_path["data"]
+        logger.info(f"init video_recorder with output_video_path: {output_video_path}")
+        rank, world_size = self.get_rank_and_world_size()
+        if output_video_path and rank == world_size - 1:
+            record_fps = self.config.get("target_fps", 16)
+            audio_sr = self.config.get("audio_sr", 16000)
+            if "video_frame_interpolation" in self.config and self.vfi_model is not None:
+                record_fps = self.config["video_frame_interpolation"]["target_fps"]
+            self.video_recorder = VideoRecorder(
+                livestream_url=output_video_path,
+                fps=record_fps,
+            )
+
+    @ProfilingContext4DebugL1("End run segment")
+    def end_run_segment(self, segment_idx=None):
+        with ProfilingContext4DebugL1("step_pre_in_rerun"):
+            self.model.scheduler.step_pre(seg_index=segment_idx, step_index=self.model.scheduler.infer_steps - 1, is_rerun=True)
+        with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
+            self.model.infer(self.inputs)
+        self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video
+        if self.is_live:
+            if self.video_recorder:
+                stream_video = vae_to_comfyui_image_inplace(self.gen_video)
+                self.video_recorder.pub_video(stream_video)
+            torch.cuda.empty_cache()
+
+    @ProfilingContext4DebugL2("Run DiT")
+    def run_main_live(self, total_steps=None):
+        try:
+            self.init_video_recorder()
+            logger.info(f"init video_recorder: {self.video_recorder}")
+            rank, world_size = self.get_rank_and_world_size()
+            if rank == world_size - 1:
+                assert self.video_recorder is not None, "video_recorder is required for stream audio input for rank 2"
+                self.video_recorder.start(self.width, self.height)
+            if world_size > 1:
+                dist.barrier()
+            self.init_run()
+            if self.config.get("compile", False):
+                self.model.select_graph_for_compile(self.input_info)
+            for segment_idx in range(self.video_segment_num):
+                logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
+                with ProfilingContext4DebugL1(
+                    f"segment end2end {segment_idx + 1}/{self.video_segment_num}",
+                    recorder_mode=GET_RECORDER_MODE(),
+                    metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration,
+                    metrics_labels=["DefaultRunner"],
+                ):
+                    self.check_stop()
+                    # 1. default do nothing
+                    self.init_run_segment(segment_idx)
+                    # 2. main inference loop
+                    latents = self.run_segment(segment_idx)
+                    # 3. vae decoder
+                    self.gen_video = self.run_vae_decoder(latents)
+                    # 4. default do nothing
+                    self.end_run_segment(segment_idx)
+        finally:
+            if hasattr(self.model, "inputs"):
+                self.end_run()
+            if self.video_recorder:
+                self.video_recorder.stop()
+                self.video_recorder = None
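The live path is opt-in: WanSFRunner.__init__ reads is_live from the config (the deploy worker sets it via align_args) and, when set, rebinds self.run_main to run_main_live, which wires up the VideoRecorder and publishes each decoded segment from end_run_segment. A minimal sketch of the rebinding pattern (class and values illustrative):

# Hedged sketch of the entry-point swap used by WanSFRunner.
class Runner:
    def __init__(self, config):
        self.is_live = config.get("is_live", False)
        if self.is_live:
            self.run_main = self.run_main_live  # instance attribute shadows the method

    def run_main(self):
        return "offline"

    def run_main_live(self):
        return "live"

assert Runner({"is_live": True}).run_main() == "live"
assert Runner({}).run_main() == "offline"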
lightx2v/models/schedulers/wan/self_forcing/scheduler.py

@@ -7,7 +7,7 @@ from lightx2v.utils.envs import *
 class WanSFScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.run_device = torch.device(config.get("run_device"), "cuda")
+        self.run_device = torch.device(config.get("run_device", "cuda"))
         self.dtype = torch.bfloat16
         self.num_frame_per_block = self.config["sf_config"]["num_frame_per_block"]
         self.num_output_frames = self.config["sf_config"]["num_output_frames"]
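The one-line scheduler change is a bug fix: torch.device(config.get("run_device"), "cuda") passed "cuda" as torch.device's second (index) argument and used None as the device type when run_device was unset; moving "cuda" inside .get makes it the fallback value instead:

import torch

# Hedged sketch: "cuda" is now the .get default, not a bogus device index.
cfg = {}  # hypothetical config without an explicit run_device
dev = torch.device(cfg.get("run_device", "cuda"))
assert dev.type == "cuda"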
test_cases/run_matrix_game2_gta_drive.sh (new file, 0 → 100644)

#!/bin/bash
# set paths first
lightx2v_path="path to LightX2V"
model_path="path to Skywork/Matrix-Game-2.0"

export CUDA_VISIBLE_DEVICES=0

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

python -m lightx2v.infer \
    --model_cls wan2.1_sf_mtxg2 \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_gta_drive.json \
    --prompt '' \
    --image_path gta_drive/0003.png \
    --save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_gta_drive.mp4 \
    --seed 42