Unverified commit 19ac1216, authored by Watebear and committed by GitHub

[feat]: support server of self-forcing & matrix-game2 (#533)

parent bcb74974
......@@ -3,8 +3,8 @@
"target_video_length": 150,
"num_output_frames": 150,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"target_height": 352,
"target_width": 640,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
......
......@@ -3,8 +3,8 @@
"target_video_length": 360,
"num_output_frames": 360,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"target_height": 352,
"target_width": 640,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
......
......@@ -3,8 +3,8 @@
"target_video_length": 150,
"num_output_frames": 150,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"target_height": 352,
"target_width": 640,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
......
......@@ -3,8 +3,8 @@
"target_video_length": 360,
"num_output_frames": 360,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"target_height": 352,
"target_width": 640,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
......
......@@ -3,8 +3,8 @@
"target_video_length": 150,
"num_output_frames": 150,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"target_height": 352,
"target_width": 640,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
......
......@@ -3,8 +3,8 @@
"target_video_length": 360,
"num_output_frames": 360,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"target_height": 352,
"target_width": 640,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
......
......@@ -23,6 +23,14 @@
"outputs": ["output_video"]
}
}
},
"self-forcing-dmd": {
"single_stage": {
"pipeline": {
"inputs": [],
"outputs": ["output_video"]
}
}
}
},
"i2v": {
......@@ -59,6 +67,30 @@
"outputs": ["output_video"]
}
}
},
"matrix-game2-gta-drive": {
"single_stage": {
"pipeline": {
"inputs": ["input_image"],
"outputs": ["output_video"]
}
}
},
"matrix-game2-universal": {
"single_stage": {
"pipeline": {
"inputs": ["input_image"],
"outputs": ["output_video"]
}
}
},
"matrix-game2-templerun": {
"single_stage": {
"pipeline": {
"inputs": ["input_image"],
"outputs": ["output_video"]
}
}
}
},
"s2v": {
......@@ -112,6 +144,7 @@
"subtask_running_timeouts": {
"t2v-wan2.1-1.3B-multi_stage-dit": 300,
"t2v-wan2.1-1.3B-single_stage-pipeline": 300,
"t2v-self-forcing-dmd-single_stage-pipeline": 300,
"i2v-wan2.1-14B-480P-multi_stage-dit": 600,
"i2v-wan2.1-14B-480P-single_stage-pipeline": 600,
"i2v-SekoTalk-Distill-single_stage-pipeline": 3600,
......
import os
import queue
import socket
import subprocess
import threading
import time
import traceback
import numpy as np
import torch
from loguru import logger
def pseudo_random(a, b):
    # Derive a pseudo-random integer in [a, b] from the sub-second digits of
    # the current wall clock (used below to pick a local TCP port).
    x = str(time.time()).split(".")[1]
    y = int(float("0." + x) * 1000000)
    return a + (y % (b - a + 1))
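# A sketch of an alternative port-selection strategy (not what this file
# does): because pseudo_random() is clock-derived, two recorders created in
# the same microsecond window can pick the same port. Binding to port 0 lets
# the OS hand out a free ephemeral port instead:
#
#     def free_tcp_port() -> int:
#         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
#             s.bind(("127.0.0.1", 0))
#             return s.getsockname()[1]  # kernel-assigned free port
#
# (A small race remains between closing this probe socket and re-binding.)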
class VideoRecorder:
def __init__(
self,
livestream_url: str,
fps: float = 16.0,
):
self.livestream_url = livestream_url
self.fps = fps
self.video_port = pseudo_random(32000, 40000)
self.ffmpeg_log_level = os.getenv("FFMPEG_LOG_LEVEL", "error")
logger.info(f"VideoRecorder video port: {self.video_port}, ffmpeg_log_level: {self.ffmpeg_log_level}")
self.width = None
self.height = None
self.stoppable_t = None
self.realtime = True
        # ffmpeg process that encodes video frames and pushes them to the livestream
self.ffmpeg_process = None
# TCP connection objects
self.video_socket = None
self.video_conn = None
self.video_thread = None
        # queue of frame tensors to be sent to the ffmpeg process
self.video_queue = queue.Queue()
def init_sockets(self):
        # TCP socket over which raw video data is sent to ffmpeg
self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.video_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.video_socket.bind(("127.0.0.1", self.video_port))
self.video_socket.listen(1)
def video_worker(self):
try:
logger.info("Waiting for ffmpeg to connect to video socket...")
self.video_conn, _ = self.video_socket.accept()
logger.info(f"Video connection established from {self.video_conn.getpeername()}")
fail_time, max_fail_time = 0, 10
packet_secs = 1.0 / self.fps
while True:
try:
if self.video_queue is None:
break
data = self.video_queue.get()
if data is None:
logger.info("Video thread received stop signal")
break
                    # Scale each frame to [0, 255] uint8 RGB and stream the raw bytes to ffmpeg
for i in range(data.shape[0]):
t0 = time.time()
frame = (data[i] * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
try:
self.video_conn.send(frame.tobytes())
except (BrokenPipeError, OSError, ConnectionResetError) as e:
logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
return
if self.realtime:
time.sleep(max(0, packet_secs - (time.time() - t0)))
fail_time = 0
except (BrokenPipeError, OSError, ConnectionResetError):
logger.info("Video connection closed during queue processing")
break
except Exception:
logger.error(f"Send video data error: {traceback.format_exc()}")
fail_time += 1
if fail_time > max_fail_time:
logger.error(f"Video push worker thread failed {fail_time} times, stopping...")
break
except Exception:
logger.error(f"Video push worker thread error: {traceback.format_exc()}")
finally:
logger.info("Video push worker thread stopped")
def start_ffmpeg_process_local(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"ffmpeg",
"-fflags",
"nobuffer",
"-analyzeduration",
"0",
"-probesize",
"32",
"-flush_packets",
"1",
"-f",
"rawvideo",
"-pix_fmt",
"rgb24",
"-color_range",
"pc",
"-colorspace",
"rgb",
"-color_primaries",
"bt709",
"-color_trc",
"iec61966-2-1",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-b:v",
"4M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-f",
"mp4",
            # Global options (-y, -loglevel) must precede the output target,
            # otherwise ffmpeg treats them as trailing options and ignores them.
            "-y",
            "-loglevel",
            self.ffmpeg_log_level,
            self.livestream_url,
        ]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start_ffmpeg_process_rtmp(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"ffmpeg",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-b:v",
"2M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-f",
"flv",
            # Global options (-y, -loglevel) must precede the output target,
            # otherwise ffmpeg treats them as trailing options and ignores them.
            "-y",
            "-loglevel",
            self.ffmpeg_log_level,
            self.livestream_url,
        ]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start_ffmpeg_process_whip(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"ffmpeg",
"-re",
"-fflags",
"nobuffer",
"-analyzeduration",
"0",
"-probesize",
"32",
"-flush_packets",
"1",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-b:v",
"2M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-threads",
"1",
"-bf",
"0",
"-f",
"whip",
            # Global options (-y, -loglevel) must precede the output target,
            # otherwise ffmpeg treats them as trailing options and ignores them.
            "-y",
            "-loglevel",
            self.ffmpeg_log_level,
            self.livestream_url,
        ]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
    def start(self, width: int, height: int):
        self.set_video_size(width, height)
        # Prime the stream with one second of black frames so ffmpeg has data
        # to read as soon as it connects to the socket.
        duration = 1.0
        self.pub_video(torch.zeros((int(self.fps * duration), height, width, 3), dtype=torch.float16))
        time.sleep(duration)
def set_video_size(self, width: int, height: int):
if self.width is not None and self.height is not None:
assert self.width == width and self.height == height, "Video size already set"
return
self.width = width
self.height = height
self.init_sockets()
if self.livestream_url.startswith("rtmp://"):
self.start_ffmpeg_process_rtmp()
elif self.livestream_url.startswith("http"):
self.start_ffmpeg_process_whip()
else:
self.start_ffmpeg_process_local()
self.realtime = False
self.video_thread = threading.Thread(target=self.video_worker)
self.video_thread.start()
# Publish ComfyUI Image tensor to livestream
def pub_video(self, images: torch.Tensor):
N, height, width, C = images.shape
assert C == 3, "Input must be [N, H, W, C] with C=3"
logger.info(f"Publishing video [{N}x{width}x{height}]")
self.set_video_size(width, height)
self.video_queue.put(images)
logger.info(f"Published {N} frames")
        # Earliest safe stop time: queued frames drain at `fps`, plus a 3 s margin
        self.stoppable_t = time.time() + N / self.fps + 3
def stop(self, wait=True):
if wait and self.stoppable_t:
t = self.stoppable_t - time.time()
if t > 0:
logger.warning(f"Waiting for {t} seconds to stop ...")
time.sleep(t)
self.stoppable_t = None
# Send stop signals to queues
if self.video_queue:
self.video_queue.put(None)
# Wait for threads to finish processing queued data (increased timeout)
queue_timeout = 30 # Increased from 5s to 30s to allow sufficient time for large video frames
if self.video_thread and self.video_thread.is_alive():
self.video_thread.join(timeout=queue_timeout)
if self.video_thread.is_alive():
logger.error(f"Video push thread did not stop after {queue_timeout}s")
# Shutdown connections to signal EOF to FFmpeg
# shutdown(SHUT_WR) will wait for send buffer to flush, no explicit sleep needed
if self.video_conn:
try:
self.video_conn.getpeername()
self.video_conn.shutdown(socket.SHUT_WR)
logger.info("Video connection shutdown initiated")
except OSError:
# Connection already closed, skip shutdown
pass
if self.ffmpeg_process:
is_local_file = not self.livestream_url.startswith(("rtmp://", "http"))
# Local MP4 files need time to write moov atom and finalize the container
timeout_seconds = 30 if is_local_file else 10
logger.info(f"Waiting for FFmpeg to finalize file (timeout={timeout_seconds}s, local_file={is_local_file})")
logger.info(f"FFmpeg output: {self.livestream_url}")
try:
returncode = self.ffmpeg_process.wait(timeout=timeout_seconds)
if returncode == 0:
logger.info(f"FFmpeg process exited successfully (exit code: {returncode})")
else:
logger.warning(f"FFmpeg process exited with non-zero code: {returncode}")
except subprocess.TimeoutExpired:
logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...")
try:
self.ffmpeg_process.terminate() # SIGTERM
returncode = self.ffmpeg_process.wait(timeout=5)
logger.warning(f"FFmpeg process terminated with SIGTERM (exit code: {returncode})")
except subprocess.TimeoutExpired:
logger.error("FFmpeg process still running after SIGTERM, killing with SIGKILL...")
self.ffmpeg_process.kill()
self.ffmpeg_process.wait() # Wait for kill to complete
logger.error("FFmpeg process killed with SIGKILL")
finally:
self.ffmpeg_process = None
if self.video_conn:
try:
self.video_conn.close()
except Exception as e:
logger.debug(f"Error closing video connection: {e}")
finally:
self.video_conn = None
if self.video_socket:
try:
self.video_socket.close()
except Exception as e:
logger.debug(f"Error closing video socket: {e}")
finally:
self.video_socket = None
if self.video_queue:
while self.video_queue.qsize() > 0:
try:
self.video_queue.get_nowait()
                except queue.Empty:
                    break
self.video_queue = None
logger.info("VideoRecorder stopped and resources cleaned up")
def __del__(self):
self.stop(wait=False)
def create_simple_video(frames=10, height=480, width=640):
video_data = []
for i in range(frames):
frame = np.zeros((height, width, 3), dtype=np.float32)
stripe_height = height // 8
colors = [
            [1.0, 0.0, 0.0],  # red
            [0.0, 1.0, 0.0],  # green
            [0.0, 0.0, 1.0],  # blue
            [1.0, 1.0, 0.0],  # yellow
            [1.0, 0.0, 1.0],  # magenta
            [0.0, 1.0, 1.0],  # cyan
            [1.0, 1.0, 1.0],  # white
            [0.5, 0.5, 0.5],  # gray
]
for j, color in enumerate(colors):
start_y = j * stripe_height
end_y = min((j + 1) * stripe_height, height)
frame[start_y:end_y, :] = color
offset = int((i / frames) * width)
frame = np.roll(frame, offset, axis=1)
frame = torch.tensor(frame, dtype=torch.float32)
video_data.append(frame)
return torch.stack(video_data, dim=0)
if __name__ == "__main__":
fps = 16
width = 640
height = 480
recorder = VideoRecorder(
# livestream_url="rtmp://localhost/live/test",
# livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
livestream_url="/path/to/output_video.mp4",
fps=fps,
)
    secs = 10  # 10 seconds of video
interval = 1
for i in range(0, secs, interval):
logger.info(f"{i} / {secs} s")
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}")
recorder.pub_video(images)
time.sleep(interval)
recorder.stop()
......@@ -296,7 +296,7 @@ async def shutdown(loop):
# align args like infer.py
def align_args(args):
args.seed = 42
-    args.sf_model_path = ""
+    args.sf_model_path = args.sf_model_path if args.sf_model_path else ""
args.use_prompt_enhancer = False
args.prompt = ""
args.negative_prompt = ""
......@@ -308,6 +308,7 @@ def align_args(args):
args.src_mask = None
args.save_result_path = ""
args.return_result_tensor = False
args.is_live = True
# =========================
......@@ -335,6 +336,7 @@ if __name__ == "__main__":
parser.add_argument("--metric_port", type=int, default=8001)
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--sf_model_path", type=str, default="")
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--server", type=str, default="http://127.0.0.1:8080")
......
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from typing import Any, Dict
import torch
......@@ -35,6 +36,7 @@ class WanSFPreInferModuleOutput:
seq_lens: torch.Tensor
freqs: torch.Tensor
context: torch.Tensor
conditional_dict: Dict[str, Any] = field(default_factory=dict)
class WanSFPreInfer(WanPreInfer):
......
......@@ -33,7 +33,7 @@ class WanActionTransformerWeights(WeightModule):
if i in action_blocks:
block_list.append(WanTransformerActionBlock(i, self.task, self.mm_type, self.config, "blocks"))
else:
-            block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config, "blocks"))
+            block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, "blocks"))
self.blocks = WeightModuleList(block_list)
self.add_module("blocks", self.blocks)
......@@ -82,6 +82,7 @@ class WanTransformerActionBlock(WeightModule):
task,
mm_type,
config,
False,
self.lazy_load,
self.lazy_load_file,
),
......@@ -109,6 +110,7 @@ class WanTransformerActionBlock(WeightModule):
task,
mm_type,
config,
False,
self.lazy_load,
self.lazy_load_file,
),
......
import os
import torch
-from diffusers.utils import load_image
+from diffusers.utils.loading_utils import load_image
from torchvision.transforms import v2
from lightx2v.models.input_encoders.hf.wan.matrix_game2.clip import CLIPModel
......@@ -272,6 +272,55 @@ class WanSFMtxg2Runner(WanSFRunner):
if stop == "n":
break
stop = "n"
gen_video_final = self.process_images_after_vae_decoder()
self.end_run()
return gen_video_final
@ProfilingContext4DebugL2("Run DiT")
def run_main_live(self, total_steps=None):
try:
self.init_video_recorder()
logger.info(f"init video_recorder: {self.video_recorder}")
rank, world_size = self.get_rank_and_world_size()
if rank == world_size - 1:
                assert self.video_recorder is not None, "video_recorder is required on the last rank for live video output"
self.video_recorder.start(self.width, self.height)
if world_size > 1:
dist.barrier()
self.init_run()
if self.config.get("compile", False):
self.model.select_graph_for_compile(self.input_info)
stop = ""
while stop != "n":
for segment_idx in range(self.video_segment_num):
logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
with ProfilingContext4DebugL1(
f"segment end2end {segment_idx + 1}/{self.video_segment_num}",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration,
metrics_labels=["DefaultRunner"],
):
self.check_stop()
                        # 1. per-segment setup (no-op by default)
self.init_run_segment(segment_idx)
# 2. main inference loop
latents = self.run_segment(segment_idx=segment_idx)
# 3. vae decoder
self.gen_video = self.run_vae_decoder(latents)
                        # 4. per-segment teardown (no-op by default)
self.end_run_segment(segment_idx)
# 5. stop or not
if self.config["streaming"]:
stop = input("Press `n` to stop generation: ").strip().lower()
if stop == "n":
break
stop = "n"
finally:
if hasattr(self.model, "inputs"):
self.end_run()
if self.video_recorder:
self.video_recorder.stop()
self.video_recorder = None
......@@ -3,15 +3,18 @@ import gc
import torch
from loguru import logger
from lightx2v.deploy.common.video_recorder import VideoRecorder
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.sf_model import WanSFModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.self_forcing.scheduler import WanSFScheduler
from lightx2v.models.video_encoders.hf.wan.vae_sf import WanSFVAE
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import vae_to_comfyui_image_inplace
@RUNNER_REGISTER("wan2.1_sf")
......@@ -19,6 +22,11 @@ class WanSFRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
self.vae_cls = WanSFVAE
self.is_live = config.get("is_live", False)
if self.is_live:
self.width = self.config["target_width"]
self.height = self.config["target_height"]
self.run_main = self.run_main_live
def load_transformer(self):
model = WanSFModel(
......@@ -61,14 +69,6 @@ class WanSFRunner(WanRunner):
def init_run(self):
super().init_run()
@ProfilingContext4DebugL1("End run segment")
def end_run_segment(self, segment_idx=None):
with ProfilingContext4DebugL1("step_pre_in_rerun"):
self.model.scheduler.step_pre(seg_index=segment_idx, step_index=self.model.scheduler.infer_steps - 1, is_rerun=True)
with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
self.model.infer(self.inputs)
self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video
@peak_memory_decorator
def run_segment(self, segment_idx=0):
infer_steps = self.model.scheduler.infer_steps
......@@ -93,3 +93,83 @@ class WanSFRunner(WanRunner):
self.progress_callback((current_step / total_all_steps) * 100, 100)
return self.model.scheduler.stream_output
def get_rank_and_world_size(self):
rank = 0
world_size = 1
if dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
return rank, world_size
def init_video_recorder(self):
output_video_path = self.input_info.save_result_path
self.video_recorder = None
if isinstance(output_video_path, dict):
output_video_path = output_video_path["data"]
logger.info(f"init video_recorder with output_video_path: {output_video_path}")
rank, world_size = self.get_rank_and_world_size()
if output_video_path and rank == world_size - 1:
record_fps = self.config.get("target_fps", 16)
audio_sr = self.config.get("audio_sr", 16000)
if "video_frame_interpolation" in self.config and self.vfi_model is not None:
record_fps = self.config["video_frame_interpolation"]["target_fps"]
self.video_recorder = VideoRecorder(
livestream_url=output_video_path,
fps=record_fps,
)
@ProfilingContext4DebugL1("End run segment")
def end_run_segment(self, segment_idx=None):
with ProfilingContext4DebugL1("step_pre_in_rerun"):
self.model.scheduler.step_pre(seg_index=segment_idx, step_index=self.model.scheduler.infer_steps - 1, is_rerun=True)
with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
self.model.infer(self.inputs)
self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video
if self.is_live:
if self.video_recorder:
stream_video = vae_to_comfyui_image_inplace(self.gen_video)
self.video_recorder.pub_video(stream_video)
torch.cuda.empty_cache()
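    # `vae_to_comfyui_image_inplace` must hand VideoRecorder.pub_video a
    # [N, H, W, 3] float tensor in [0, 1] (pub_video scales by 255 itself).
    # A minimal sketch of such a conversion, assuming the VAE decoder emits a
    # [B, C, T, H, W] tensor in [-1, 1] (the real helper in
    # lightx2v.utils.utils may differ):
    #
    #     def to_image_frames(x: torch.Tensor) -> torch.Tensor:
    #         x = x[0].permute(1, 2, 3, 0)     # -> [T, H, W, C]
    #         return (x.clamp(-1, 1) + 1) / 2  # [-1, 1] -> [0, 1]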
@ProfilingContext4DebugL2("Run DiT")
def run_main_live(self, total_steps=None):
try:
self.init_video_recorder()
logger.info(f"init video_recorder: {self.video_recorder}")
rank, world_size = self.get_rank_and_world_size()
if rank == world_size - 1:
                assert self.video_recorder is not None, "video_recorder is required on the last rank for live video output"
self.video_recorder.start(self.width, self.height)
if world_size > 1:
dist.barrier()
self.init_run()
if self.config.get("compile", False):
self.model.select_graph_for_compile(self.input_info)
for segment_idx in range(self.video_segment_num):
logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
with ProfilingContext4DebugL1(
f"segment end2end {segment_idx + 1}/{self.video_segment_num}",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration,
metrics_labels=["DefaultRunner"],
):
self.check_stop()
                    # 1. per-segment setup (no-op by default)
self.init_run_segment(segment_idx)
# 2. main inference loop
latents = self.run_segment(segment_idx)
# 3. vae decoder
self.gen_video = self.run_vae_decoder(latents)
                    # 4. per-segment teardown (no-op by default)
self.end_run_segment(segment_idx)
finally:
if hasattr(self.model, "inputs"):
self.end_run()
if self.video_recorder:
self.video_recorder.stop()
self.video_recorder = None
......@@ -7,7 +7,7 @@ from lightx2v.utils.envs import *
class WanSFScheduler(WanScheduler):
def __init__(self, config):
super().__init__(config)
self.run_device = torch.device(config.get("run_device"), "cuda")
self.run_device = torch.device(config.get("run_device", "cuda"))
self.dtype = torch.bfloat16
self.num_frame_per_block = self.config["sf_config"]["num_frame_per_block"]
self.num_output_frames = self.config["sf_config"]["num_output_frames"]
......
#!/bin/bash
# set paths first
lightx2v_path=/path/to/LightX2V
model_path=/path/to/Skywork/Matrix-Game-2.0
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1_sf_mtxg2 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_gta_drive.json \
--prompt '' \
--image_path gta_drive/0003.png \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_gta_drive.mp4 \
--seed 42
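
# The same entry point should serve the other Matrix-Game-2.0 pipelines
# registered in the commit above; assuming their config files follow the same
# naming scheme (an assumption, not verified here), swap --config_json to
# ${lightx2v_path}/configs/matrix_game2/matrix_game2_universal.json or
# ${lightx2v_path}/configs/matrix_game2/matrix_game2_templerun.json.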