Commit 99a6f046 authored by wangshankun

Merge branch 'main' of https://github.com/ModelTC/LightX2V into main

parents 8bdefedf 068a47db
#!/usr/bin/env python

# ----- entrypoint before this merge (removed) -----
import argparse
import atexit
import signal
import sys
from pathlib import Path

import uvicorn
from loguru import logger

from lightx2v.server.api import ApiServer
from lightx2v.server.service import DistributedInferenceService


def create_signal_handler(inference_service: DistributedInferenceService):
    """Create unified signal handler function"""

    def signal_handler(signum, frame):
        logger.info(f"Received signal {signum}, gracefully shutting down...")
        try:
            if inference_service.is_running:
                inference_service.stop_distributed_inference()
        except Exception as e:
            logger.error(f"Error occurred while shutting down distributed inference service: {str(e)}")
        finally:
            sys.exit(0)

    return signal_handler


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_cls",
        type=str,
        required=True,
        choices=[
            "wan2.1",
            "hunyuan",
            "wan2.1_distill",
            "wan2.1_causvid",
            "wan2.1_skyreels_v2_df",
            "wan2.1_audio",
            "wan2.2_moe",
            "wan2.2_moe_distill",
        ],
        default="wan2.1",
    )
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--split", action="store_true")
    parser.add_argument("--lora_path", type=str, required=False, default=None)
    parser.add_argument("--lora_strength", type=float, default=1.0, help="The strength for the lora (default: 1.0)")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node for distributed inference")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    cache_dir = Path(__file__).parent.parent / "server_cache"
    inference_service = DistributedInferenceService()
    api_server = ApiServer()
    api_server.initialize_services(cache_dir, inference_service)

    signal_handler = create_signal_handler(inference_service)
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    logger.info("Starting distributed inference service...")
    success = inference_service.start_distributed_inference(args)
    if not success:
        logger.error("Failed to start distributed inference service, exiting program")
        sys.exit(1)
    atexit.register(inference_service.stop_distributed_inference)

    try:
        logger.info(f"Starting FastAPI server on port: {args.port}")
        uvicorn.run(
            api_server.get_app(),
            host="0.0.0.0",
            port=args.port,
            reload=False,
            workers=1,
        )
    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down service...")
    except Exception as e:
        logger.error(f"Error occurred while running FastAPI server: {str(e)}")
    finally:
        inference_service.stop_distributed_inference()


# ----- entrypoint after this merge (added) -----
import argparse
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from lightx2v.server.main import run_server


def main():
    parser = argparse.ArgumentParser(description="Run LightX2V inference server")
    parser.add_argument("--model_path", type=str, required=True, help="Path to model")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
    parser.add_argument("--config_json", type=str, help="Path to model config JSON file")
    parser.add_argument("--task", type=str, default="i2v", help="Task type (i2v, etc.)")
    parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node (GPUs to use)")
    parser.add_argument("--port", type=int, default=8000, help="Server port")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host")
    args = parser.parse_args()

    run_server(args)


if __name__ == "__main__":
......
import glob
import os
import torch
from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.model import WanModel
...@@ -27,87 +24,6 @@ class WanAudioModel(WanModel):
self.pre_infer_class = WanAudioPreInfer
self.post_infer_class = WanAudioPostInfer
@torch.no_grad()
def infer(self, inputs):
if self.config["cpu_offload"]:
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape
num_frame = c + 1 # for r2v
video_token_num = num_frame * (h // 2) * (w // 2)
self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_cond
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
@torch.no_grad()
def infer_wo_cfg_parallel(self, inputs):
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0:
self.to_cuda()
elif self.offload_granularity != "model":
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape
num_frame = c + 1 # for r2v
video_token_num = num_frame * (h // 2) * (w // 2)
self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
self.scheduler.noise_pred = noise_pred_cond
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
torch.cuda.empty_cache()
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
torch.cuda.empty_cache()
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
self.to_cpu()
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
class Wan22MoeAudioModel(WanAudioModel):
def _load_ckpt(self, unified_dtype, sensitive_layer):
......
...@@ -18,30 +18,11 @@ class WanAudioPostInfer(WanPostInfer):
self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, x, e, grid_sizes, valid_patch_length):
if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
elif e.dim() == 3: # For Diffusion forcing
modulation = weights.head_modulation.tensor.unsqueeze(2) # 1, 2, seq, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
e = [ei.squeeze(1) for ei in e]
x = weights.norm.apply(x)
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype)
x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.infer_dtype)
x = weights.head.apply(x)
x = x[:, :valid_patch_length]
x = self.unpatchify(x, grid_sizes)
if self.clean_cuda_cache:
del e, grid_sizes
torch.cuda.empty_cache()
return [u.float() for u in x]
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, x, pre_infer_out):
x = x[:, : pre_infer_out.valid_patch_length]
x = self.unpatchify(x, pre_infer_out.grid_sizes)
if self.clean_cuda_cache:
torch.cuda.empty_cache()
return [u.float() for u in x]
......
...@@ -3,6 +3,7 @@ import torch
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.utils.envs import *
from ..module_io import WanPreInferModuleOutput
from ..utils import rope_params, sinusoidal_embedding_1d, masks_like
from loguru import logger
...@@ -133,4 +134,14 @@ class WanAudioPreInfer(WanPreInfer):
del context_clip
torch.cuda.empty_cache()
return (embed, x_grid_sizes, (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context, audio_dit_blocks), valid_patch_length)
return WanPreInferModuleOutput(
embed=embed,
grid_sizes=x_grid_sizes,
x=x.squeeze(0),
embed0=embed0.squeeze(0),
seq_lens=seq_lens,
freqs=self.freqs,
context=context,
audio_dit_blocks=audio_dit_blocks,
valid_patch_length=valid_patch_length,
)
import torch
import torch.distributed as dist
import torch.nn.functional as F
from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
from lightx2v.models.networks.wan.infer.utils import pad_freqs
class WanTransformerDistInfer(WanTransformerInfer):
def __init__(self, config, seq_p_group=None):
super().__init__(config)
self.seq_p_group = seq_p_group
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
x, embed0 = self.dist_pre_process(x, embed0)
x = super().infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
x = self.dist_post_process(x)
return x
def compute_freqs(self, q, grid_sizes, freqs):
if "audio" in self.config.get("model_cls", ""):
freqs_i = self.compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = self.compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
return freqs_i
def dist_pre_process(self, x, embed0):
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)
padding_size = (world_size - (x.shape[0] % world_size)) % world_size
if padding_size > 0:
# Pad the first dimension with F.pad
x = F.pad(x, (0, 0, 0, padding_size)) # F.pad arguments run from the last dimension backwards
x = torch.chunk(x, world_size, dim=0)[cur_rank]
if self.config["model_cls"].startswith("wan2.2"):
embed0 = torch.chunk(embed0, world_size, dim=0)[cur_rank]
return x, embed0
def dist_post_process(self, x):
world_size = dist.get_world_size(self.seq_p_group)
# Allocate one buffer per rank to hold every process's output
gathered_x = [torch.empty_like(x) for _ in range(world_size)]
# Gather the outputs from all processes
dist.all_gather(gathered_x, x, group=self.seq_p_group)
# Concatenate the gathered outputs along the sequence dimension
combined_output = torch.cat(gathered_x, dim=0)
return combined_output # return the merged output
def compute_freqs_dist(self, s, c, grid_sizes, freqs):
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
def compute_freqs_audio_dist(self, s, c, grid_sizes, freqs):
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
valid_token_length = f * h * w
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i[valid_token_length:, :, :f] = 0
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
from dataclasses import dataclass
from typing import List
import torch
@dataclass
class WanPreInferModuleOutput:
embed: torch.Tensor
grid_sizes: torch.Tensor
x: torch.Tensor
embed0: torch.Tensor
seq_lens: torch.Tensor
freqs: torch.Tensor
context: torch.Tensor
audio_dit_blocks: List = None
valid_patch_length: int = None
...@@ -10,35 +10,15 @@ class WanPostInfer:
self.out_dim = config["out_dim"]
self.patch_size = (1, 2, 2)
self.clean_cuda_cache = config.get("clean_cuda_cache", False)
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, x, e, grid_sizes):
if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
elif e.dim() == 3: # For Diffusion forcing
modulation = weights.head_modulation.tensor.unsqueeze(2) # 1, 2, seq, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
e = [ei.squeeze(1) for ei in e]
x = weights.norm.apply(x)
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype)
x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.infer_dtype)
x = weights.head.apply(x)
x = self.unpatchify(x, grid_sizes)
if self.clean_cuda_cache:
del e, grid_sizes
torch.cuda.empty_cache()
return [u.float() for u in x]
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, x, pre_infer_out):
x = self.unpatchify(x, pre_infer_out.grid_sizes)
if self.clean_cuda_cache:
torch.cuda.empty_cache()
return [u.float() for u in x]
......
...@@ -2,6 +2,7 @@ import torch
from lightx2v.utils.envs import *
from .module_io import WanPreInferModuleOutput
from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d
...@@ -132,8 +133,13 @@ class WanPreInfer:
if self.config.get("use_image_encoder", True):
del context_clip
torch.cuda.empty_cache()
return (
embed,
grid_sizes,
(x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context),
)
return WanPreInferModuleOutput(
embed=embed,
grid_sizes=grid_sizes,
x=x.squeeze(0),
embed0=embed0.squeeze(0),
seq_lens=seq_lens,
freqs=self.freqs,
context=context,
)
...@@ -9,7 +9,7 @@ from lightx2v.common.offload.manager import (
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio
from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio, compute_freqs_audio_dist, compute_freqs_dist
class WanTransformerInfer(BaseTransformerInfer):
...@@ -33,7 +33,11 @@ class WanTransformerInfer(BaseTransformerInfer):
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
if self.config["seq_parallel"]:
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
if self.config.get("cpu_offload", False):
if torch.cuda.get_device_capability(0) == (9, 0):
assert self.config["self_attn_1_type"] != "sage_attn2"
...@@ -86,6 +90,12 @@ class WanTransformerInfer(BaseTransformerInfer):
return cu_seqlens_q, cu_seqlens_k
def compute_freqs(self, q, grid_sizes, freqs):
if self.config["seq_parallel"]:
if "audio" in self.config.get("model_cls", ""):
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
else:
freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
else:
if "audio" in self.config.get("model_cls", ""):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
...@@ -93,8 +103,43 @@ class WanTransformerInfer(BaseTransformerInfer):
return freqs_i
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, pre_infer_out):
x = self.infer_func(
weights,
pre_infer_out.grid_sizes,
pre_infer_out.embed,
pre_infer_out.x,
pre_infer_out.embed0,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
pre_infer_out.context,
pre_infer_out.audio_dit_blocks,
)
return self._infer_post_blocks(weights, x, pre_infer_out.embed)
def _infer_post_blocks(self, weights, x, e):
if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
elif e.dim() == 3: # For Diffusion forcing
modulation = weights.head_modulation.tensor.unsqueeze(2) # 1, 2, seq, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
e = [ei.squeeze(1) for ei in e]
x = weights.norm.apply(x)
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype)
x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.infer_dtype)
x = weights.head.apply(x)
if self.clean_cuda_cache:
del e
torch.cuda.empty_cache()
return x
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num):
......
import torch
import torch.distributed as dist
from lightx2v.utils.envs import *
...@@ -71,6 +72,52 @@ def compute_freqs_audio(c, grid_sizes, freqs):
return freqs_i
def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
def compute_freqs_audio_dist(s, c, grid_sizes, freqs, seq_p_group):
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
valid_token_length = f * h * w
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i[valid_token_length:, :, :f] = 0
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
......
...@@ -3,11 +3,11 @@ import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
from loguru import logger
from safetensors import safe_open
from lightx2v.common.ops.attn import MaskMap
from lightx2v.models.networks.wan.infer.dist_infer.transformer_infer import WanTransformerDistInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
WanTransformerInferAdaCaching,
WanTransformerInferCustomCaching,
...@@ -83,9 +83,7 @@ class WanModel:
def _init_infer_class(self):
self.pre_infer_class = WanPreInfer
self.post_infer_class = WanPostInfer
if self.seq_p_group is not None:
self.transformer_infer_class = WanTransformerDistInfer
else:
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
...@@ -293,17 +291,8 @@ class WanModel:
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
if self.seq_p_group is not None:
self.transformer_infer = self.transformer_infer_class(self.config, self.seq_p_group)
else:
self.transformer_infer = self.transformer_infer_class(self.config)
if self.config["cfg_parallel"]:
self.infer_func = self.infer_with_cfg_parallel
else:
self.infer_func = self.infer_wo_cfg_parallel
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.pre_infer.set_scheduler(scheduler)
...@@ -322,10 +311,6 @@ class WanModel:
@torch.no_grad()
def infer(self, inputs):
return self.infer_func(inputs)
@torch.no_grad()
def infer_wo_cfg_parallel(self, inputs):
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0:
self.to_cuda()
...@@ -338,26 +323,31 @@ class WanModel:
video_token_num = c * (h // 2) * (w // 2)
self.transformer_infer.mask_map = MaskMap(video_token_num, c)
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
self.scheduler.noise_pred = noise_pred_cond
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
torch.cuda.empty_cache()
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
torch.cuda.empty_cache()
if self.config["cfg_parallel"]:
# ==================== CFG Parallel Processing ====================
cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2"
cfg_p_rank = dist.get_rank(cfg_p_group)
if cfg_p_rank == 0:
noise_pred = self._infer_cond_uncond(inputs, positive=True)
else:
noise_pred = self._infer_cond_uncond(inputs, positive=False)
noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
noise_pred_cond = noise_pred_list[0] # cfg_p_rank == 0
noise_pred_uncond = noise_pred_list[1] # cfg_p_rank == 1
else:
# ==================== CFG Processing ====================
noise_pred_cond = self._infer_cond_uncond(inputs, positive=True)
noise_pred_uncond = self._infer_cond_uncond(inputs, positive=False)
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
else:
# ==================== No CFG ====================
self.scheduler.noise_pred = self._infer_cond_uncond(inputs, positive=True)
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
...@@ -367,24 +357,62 @@ class WanModel:
self.post_weight.to_cpu()
@torch.no_grad()
def infer_with_cfg_parallel(self, inputs):
assert self.config["enable_cfg"], "enable_cfg must be True"
cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
assert dist.get_world_size(cfg_p_group) == 2, f"cfg_p_world_size must be equal to 2"
cfg_p_rank = dist.get_rank(cfg_p_group)
if cfg_p_rank == 0:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
else:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
noise_pred_cond = noise_pred_list[0] # cfg_p_rank == 0
noise_pred_uncond = noise_pred_list[1] # cfg_p_rank == 1
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
@torch.no_grad()
def _infer_cond_uncond(self, inputs, positive=True):
pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=positive)
if self.config["seq_parallel"]:
pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)
x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)
if self.config["seq_parallel"]:
x = self._seq_parallel_post_process(x)
noise_pred = self.post_infer.infer(self.post_weight, x, pre_infer_out)[0]
if self.clean_cuda_cache:
del x, pre_infer_out
torch.cuda.empty_cache()
return noise_pred
@torch.no_grad()
def _seq_parallel_pre_process(self, pre_infer_out):
embed, x, embed0 = pre_infer_out.embed, pre_infer_out.x, pre_infer_out.embed0
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)
padding_size = (world_size - (x.shape[0] % world_size)) % world_size
if padding_size > 0:
# Pad the first dimension with F.pad
x = F.pad(x, (0, 0, 0, padding_size)) # F.pad arguments run from the last dimension backwards
x = torch.chunk(x, world_size, dim=0)[cur_rank]
if self.config["model_cls"].startswith("wan2.2"):
padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
if padding_size > 0:
embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size)) # F.pad arguments run from the last dimension backwards
embed = F.pad(embed, (0, 0, 0, padding_size))
pre_infer_out.x = x
pre_infer_out.embed = embed
pre_infer_out.embed0 = embed0
return pre_infer_out
@torch.no_grad()
def _seq_parallel_post_process(self, x):
world_size = dist.get_world_size(self.seq_p_group)
# Allocate one buffer per rank to hold every process's output
gathered_x = [torch.empty_like(x) for _ in range(world_size)]
# Gather the outputs from all processes
dist.all_gather(gathered_x, x, group=self.seq_p_group)
# Concatenate the gathered outputs along the sequence dimension
combined_output = torch.cat(gathered_x, dim=0)
return combined_output # return the merged output
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
LN_WEIGHT_REGISTER,
MM_WEIGHT_REGISTER,
TENSOR_REGISTER,
)
class WanPostWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
self.register_parameter(
"norm",
LN_WEIGHT_REGISTER["Default"](),
)
self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
...@@ -26,6 +26,11 @@ class WanTransformerWeights(WeightModule):
self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
self.add_module("blocks", self.blocks)
# post blocks weights
self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]())
self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
def clear(self):
for block in self.blocks:
for phase in block.compute_phases:
......
# LightX2V Server
## Overview
The LightX2V server is a distributed video generation service built with FastAPI that processes image-to-video tasks using a multi-process architecture with GPU support. It implements a sophisticated task queue system with distributed inference capabilities for high-throughput video generation workloads.
## Architecture
### System Architecture
```mermaid
graph TB
subgraph "Client Layer"
Client[HTTP Client]
end
subgraph "API Layer"
FastAPI[FastAPI Application]
ApiServer[ApiServer]
Router1[Tasks Router<br/>/v1/tasks]
Router2[Files Router<br/>/v1/files]
Router3[Service Router<br/>/v1/service]
end
subgraph "Service Layer"
TaskManager[TaskManager<br/>Thread-safe Task Queue]
FileService[FileService<br/>File I/O & Downloads]
VideoService[VideoGenerationService]
end
subgraph "Processing Layer"
Thread[Processing Thread<br/>Sequential Task Loop]
end
subgraph "Distributed Inference Layer"
DistService[DistributedInferenceService]
SharedData[(Shared Data<br/>mp.Manager.dict)]
TaskEvent[Task Event<br/>mp.Manager.Event]
ResultEvent[Result Event<br/>mp.Manager.Event]
subgraph "Worker Processes"
W0[Worker 0<br/>Master/Rank 0]
W1[Worker 1<br/>Rank 1]
WN[Worker N<br/>Rank N]
end
end
subgraph "Resource Management"
GPUManager[GPUManager<br/>GPU Detection & Allocation]
DistManager[DistributedManager<br/>PyTorch Distributed]
Config[ServerConfig<br/>Configuration]
end
Client -->|HTTP Request| FastAPI
FastAPI --> ApiServer
ApiServer --> Router1
ApiServer --> Router2
ApiServer --> Router3
Router1 -->|Create/Manage Tasks| TaskManager
Router1 -->|Process Tasks| Thread
Router2 -->|File Operations| FileService
Router3 -->|Service Status| TaskManager
Thread -->|Get Pending Tasks| TaskManager
Thread -->|Generate Video| VideoService
VideoService -->|Download Images| FileService
VideoService -->|Submit Task| DistService
DistService -->|Update| SharedData
DistService -->|Signal| TaskEvent
TaskEvent -->|Notify| W0
W0 -->|Broadcast| W1
W0 -->|Broadcast| WN
W0 -->|Update Result| SharedData
W0 -->|Signal| ResultEvent
ResultEvent -->|Notify| DistService
W0 -.->|Uses| GPUManager
W1 -.->|Uses| GPUManager
WN -.->|Uses| GPUManager
W0 -.->|Setup| DistManager
W1 -.->|Setup| DistManager
WN -.->|Setup| DistManager
DistService -.->|Reads| Config
ApiServer -.->|Reads| Config
```
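The hand-off between `DistributedInferenceService` and the rank-0 worker shown above boils down to a `multiprocessing.Manager` dict plus two events. A minimal sketch of that handshake, using the `current_task`/`result` keys from the diagram (the real service adds timeouts, task broadcasting, and error handling):

```python
import multiprocessing as mp

def worker_loop(shared, task_event, result_event):
    """Rank-0 worker: wait for a task, 'run' it, publish the result."""
    while True:
        task_event.wait()            # blocks until the service signals a new task
        task_event.clear()
        task = shared.get("current_task")
        if task is None:             # sentinel used here to shut the worker down
            break
        shared["result"] = {"task_id": task["task_id"], "status": "completed"}
        result_event.set()           # wake the service waiting for the result

if __name__ == "__main__":
    manager = mp.Manager()
    shared, task_event, result_event = manager.dict(), manager.Event(), manager.Event()
    proc = mp.Process(target=worker_loop, args=(shared, task_event, result_event))
    proc.start()

    shared["current_task"] = {"task_id": "demo"}   # what submit_task() would publish
    task_event.set()
    result_event.wait()
    print(shared["result"])

    shared["current_task"] = None                  # ask the worker to exit
    task_event.set()
    proc.join()
```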
## Task Processing Flow
```mermaid
sequenceDiagram
participant C as Client
participant API as API Server
participant TM as TaskManager
participant PT as Processing Thread
participant VS as VideoService
participant FS as FileService
participant DIS as Distributed<br/>Inference Service
participant W0 as Worker 0<br/>(Master)
participant W1 as Worker 1..N
C->>API: POST /v1/tasks<br/>(Create Task)
API->>TM: create_task()
TM->>TM: Generate task_id
TM->>TM: Add to queue<br/>(status: PENDING)
API->>PT: ensure_processing_thread()
API-->>C: TaskResponse<br/>(task_id, status: pending)
Note over PT: Processing Loop
PT->>TM: get_next_pending_task()
TM-->>PT: task_id
PT->>TM: acquire_processing_lock()
PT->>TM: start_task()<br/>(status: PROCESSING)
PT->>VS: generate_video_with_stop_event()
alt Image is URL
VS->>FS: download_image()
FS->>FS: HTTP download<br/>with retry
FS-->>VS: image_path
else Image is Base64
VS->>FS: save_base64_image()
FS-->>VS: image_path
else Image is Upload
VS->>FS: validate_file()
FS-->>VS: image_path
end
VS->>DIS: submit_task(task_data)
DIS->>DIS: shared_data["current_task"] = task_data
DIS->>DIS: task_event.set()
Note over W0,W1: Distributed Processing
W0->>W0: task_event.wait()
W0->>W0: Get task from shared_data
W0->>W1: broadcast_task_data()
par Parallel Inference
W0->>W0: run_pipeline()
and
W1->>W1: run_pipeline()
end
W0->>W0: barrier() for sync
W0->>W0: shared_data["result"] = result
W0->>DIS: result_event.set()
DIS->>DIS: result_event.wait()
DIS->>VS: return result
VS-->>PT: TaskResponse
PT->>TM: complete_task()<br/>(status: COMPLETED)
PT->>TM: release_processing_lock()
Note over C: Client Polling
C->>API: GET /v1/tasks/{task_id}/status
API->>TM: get_task_status()
TM-->>API: status info
API-->>C: Task Status
C->>API: GET /v1/tasks/{task_id}/result
API->>TM: get_task_status()
API->>FS: stream_file_response()
FS-->>API: Video Stream
API-->>C: Video File
```
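The flow above can be exercised with any HTTP client. A minimal polling client sketch (field names beyond `image_path` and `save_video_path` depend on the `TaskRequest` schema, and the status strings are assumed to match the states listed below):

```python
import time
import httpx

payload = {
    "image_path": "https://example.com/input.jpg",  # URL, base64, or uploaded file
    "save_video_path": "output.mp4",
    # remaining TaskRequest fields (prompt, seed, ...) depend on the schema
}

with httpx.Client(base_url="http://localhost:8000", timeout=30) as client:
    task_id = client.post("/v1/tasks/", json=payload).json()["task_id"]

    while True:  # poll until the task leaves the queue
        status = client.get(f"/v1/tasks/{task_id}/status").json()
        if status["status"] in ("completed", "failed", "cancelled"):
            break
        time.sleep(2)

    if status["status"] == "completed":
        video = client.get(f"/v1/tasks/{task_id}/result")
        with open("result.mp4", "wb") as f:
            f.write(video.content)
```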
## Task States
```mermaid
stateDiagram-v2
[*] --> PENDING: create_task()
PENDING --> PROCESSING: start_task()
PROCESSING --> COMPLETED: complete_task()
PROCESSING --> FAILED: fail_task()
PENDING --> CANCELLED: cancel_task()
PROCESSING --> CANCELLED: cancel_task()
COMPLETED --> [*]
FAILED --> [*]
CANCELLED --> [*]
```
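In code these states map naturally onto a string enum; a sketch consistent with the `TaskStatus` usage in the API layer (the exact member values are assumptions):

```python
from enum import Enum

class TaskStatus(str, Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

# e.g. the result route only serves files for tasks whose status equals TaskStatus.COMPLETED.value
```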
## Configuration
### Environment Variables
See `lightx2v/server/config.py` for the full list of supported environment variables and their defaults.
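For example, the host and port can be overridden before the config is loaded (variable names follow `ServerConfig.from_env` shown later in this diff; the import path is assumed from the file location):

```python
import os

os.environ["LIGHTX2V_HOST"] = "0.0.0.0"
os.environ["LIGHTX2V_PORT"] = "9000"

from lightx2v.server.config import ServerConfig  # assumed module path for lightx2v/server/config.py

config = ServerConfig.from_env()
print(config.host, config.port)  # -> 0.0.0.0 9000
```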
### Command Line Arguments
```bash
python -m lightx2v.server.main \
--model_path /path/to/model \
--model_cls wan2.1_distill \
--task i2v \
--host 0.0.0.0 \
--port 8000 \
--config_json /path/to/xxx_config.json
```
```bash
python -m lightx2v.server.main \
--model_path /path/to/model \
--model_cls wan2.1_distill \
--task i2v \
--host 0.0.0.0 \
--port 8000 \
--config_json /path/to/xxx_dist_config.json \
--nproc_per_node 2
```
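The same entrypoint can also be driven programmatically; a sketch that builds the argument namespace by hand and calls `run_server` directly, mirroring the CLI flags above:

```python
from argparse import Namespace

from lightx2v.server.main import run_server

args = Namespace(
    model_path="/path/to/model",
    model_cls="wan2.1_distill",
    task="i2v",
    config_json="/path/to/xxx_config.json",
    nproc_per_node=1,
    host="0.0.0.0",
    port=8000,
)
run_server(args)  # blocks until the server is shut down
```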
## Key Features
### 1. Distributed Processing
- **Multi-process architecture** for GPU parallelization
- **Master-worker pattern** with rank 0 as coordinator
- **PyTorch distributed** backend (NCCL for GPU, Gloo for CPU)
- **Automatic GPU allocation** across processes
- **Task broadcasting** with chunked pickle serialization (see the sketch below)
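A minimal sketch of the chunked pickle broadcast (CPU tensors over a gloo group are assumed; the real worker code may use a different chunk size and transport):

```python
import pickle
import torch
import torch.distributed as dist

CHUNK_BYTES = 1 << 20  # 1 MiB per broadcast; an arbitrary size for this sketch

def broadcast_task_data(task_data=None, src=0, group=None):
    """Send an arbitrary picklable object from `src` to every rank in `group`."""
    rank = dist.get_rank(group)
    if rank == src:
        payload = pickle.dumps(task_data)
        size = torch.tensor([len(payload)], dtype=torch.long)
    else:
        payload, size = b"", torch.zeros(1, dtype=torch.long)
    dist.broadcast(size, src=src, group=group)           # 1) announce payload length

    received = bytearray()
    for start in range(0, int(size.item()), CHUNK_BYTES):
        end = min(start + CHUNK_BYTES, int(size.item()))
        if rank == src:
            chunk = torch.tensor(list(payload[start:end]), dtype=torch.uint8)
        else:
            chunk = torch.zeros(end - start, dtype=torch.uint8)
        dist.broadcast(chunk, src=src, group=group)       # 2) stream fixed-size chunks
        received.extend(chunk.tolist())
    return task_data if rank == src else pickle.loads(bytes(received))
```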
### 2. Task Queue Management
- **Thread-safe** task queue with locks (see the sketch after this list)
- **Sequential processing** with single processing thread
- **Configurable queue limits** with overflow protection
- **Task prioritization** (FIFO)
- **Automatic cleanup** of old completed tasks
- **Cancellation support** for pending and running tasks
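A stripped-down illustration of that queue discipline (the real `task_manager` tracks more state, timestamps, and cleanup; `MiniTaskQueue` here is purely illustrative):

```python
import threading
import uuid
from collections import OrderedDict

class MiniTaskQueue:
    """FIFO queue guarded by a lock, with a size limit and simple status tracking."""

    def __init__(self, max_size=10):
        self._lock = threading.Lock()
        self._tasks = OrderedDict()  # task_id -> {"status", "message"}
        self._max_size = max_size

    def create_task(self, message):
        with self._lock:
            pending = sum(1 for t in self._tasks.values() if t["status"] == "pending")
            if pending >= self._max_size:
                raise RuntimeError("queue full")  # overflow protection
            task_id = str(uuid.uuid4())
            self._tasks[task_id] = {"status": "pending", "message": message}
            return task_id

    def get_next_pending_task(self):
        with self._lock:  # FIFO: the oldest pending task wins
            for task_id, info in self._tasks.items():
                if info["status"] == "pending":
                    info["status"] = "processing"
                    return task_id
            return None
```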
### 3. File Management
- **Multiple input formats**: URL, base64, file upload
- **HTTP downloads** with exponential backoff retry (see the sketch after this list)
- **Streaming responses** for large video files
- **Cache management** with automatic cleanup
- **File validation** and format detection
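A sketch of the download-with-retry behaviour using `httpx`, which the API module already imports (the real `FileService` also validates content and writes into its cache directory; `download_with_retry` is an illustrative name):

```python
import time
import httpx

def download_with_retry(url: str, dest: str, max_retries: int = 3) -> str:
    """Download `url` to `dest`, doubling the wait after each failed attempt."""
    delay = 1.0
    for attempt in range(1, max_retries + 1):
        try:
            with httpx.Client(timeout=30, follow_redirects=True) as client:
                resp = client.get(url)
                resp.raise_for_status()
            with open(dest, "wb") as f:
                f.write(resp.content)
            return dest
        except httpx.HTTPError:
            if attempt == max_retries:
                raise
            time.sleep(delay)  # exponential backoff: 1s, 2s, 4s, ...
            delay *= 2
```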
## Performance Considerations
1. **Single Task Processing**: Tasks are processed sequentially to manage GPU memory effectively
2. **Multi-GPU Support**: Distributes inference across available GPUs for parallelization
3. **Connection Pooling**: Reuses HTTP connections to reduce overhead
4. **Streaming Responses**: Large files are streamed to avoid memory issues (see the sketch after this list)
5. **Queue Management**: Automatic task cleanup prevents memory leaks
6. **Process Isolation**: Distributed workers run in separate processes for stability
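Item 4 corresponds to FastAPI's `StreamingResponse`; a minimal sketch that serves a finished video in 1 MiB chunks (the real `_stream_file_response` additionally enforces the output-directory check shown in the API code; `stream_video` is an illustrative name):

```python
from pathlib import Path
from fastapi.responses import StreamingResponse

def stream_video(file_path: Path, chunk_size: int = 1024 * 1024) -> StreamingResponse:
    """Yield the file in fixed-size chunks instead of loading it into memory."""
    def iter_file():
        with open(file_path, "rb") as f:
            while chunk := f.read(chunk_size):
                yield chunk

    return StreamingResponse(
        iter_file(),
        media_type="video/mp4",
        headers={"Content-Disposition": f'attachment; filename="{file_path.name}"'},
    )
```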
## Monitoring and Debugging
### Logging
The server uses `loguru` for structured logging. Logs include:
- Request/response details
- Task lifecycle events
- Worker process status
- Error traces with context
### Health Checks
- `/v1/service/status` - Overall service health
- `/v1/tasks/queue/status` - Queue status and processing state (see the example after this list)
- Process monitoring via system tools (htop, nvidia-smi)
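A quick probe of both endpoints, assuming the default port:

```python
import httpx

with httpx.Client(base_url="http://localhost:8000", timeout=5) as client:
    print(client.get("/v1/service/status").json())      # overall service health
    print(client.get("/v1/tasks/queue/status").json())   # pending/active counts, queue capacity
```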
### Common Issues
1. **GPU Out of Memory**: Reduce `nproc_per_node` or adjust model batch size
2. **Task Timeout**: Increase `LIGHTX2V_TASK_TIMEOUT` for longer videos
3. **Queue Full**: Increase `LIGHTX2V_MAX_QUEUE_SIZE` or add rate limiting
## Security Considerations
1. **Input Validation**: All inputs validated with Pydantic schemas
## License
See the main project LICENSE file for licensing information.
import asyncio
import gc
import threading
import time
import uuid
from pathlib import Path
from typing import Optional
from typing import Any, Optional
from urllib.parse import urlparse
import httpx
import torch
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from loguru import logger
from .schema import (
ServiceStatusResponse,
StopTaskResponse,
TaskRequest,
TaskResponse,
)
from .service import DistributedInferenceService, FileService, VideoGenerationService
from .utils import ServiceStatus
from .task_manager import TaskStatus, task_manager
class ApiServer:
def __init__(self):
def __init__(self, max_queue_size: int = 10, app: Optional[FastAPI] = None):
self.app = FastAPI(title="LightX2V API", version="1.0.0")
self.app = app or FastAPI(title="LightX2V API", version="1.0.0")
self.file_service = None
self.inference_service = None
self.video_service = None
self.thread = None
self.max_queue_size = max_queue_size
self.stop_generation_event = threading.Event()
self.processing_thread = None
self.stop_processing = threading.Event()
# Create routers
self.tasks_router = APIRouter(prefix="/v1/tasks", tags=["tasks"])
self.files_router = APIRouter(prefix="/v1/files", tags=["files"])
self.service_router = APIRouter(prefix="/v1/service", tags=["service"])
...@@ -37,7 +40,6 @@ class ApiServer:
self._setup_routes()
def _setup_routes(self):
"""Setup routes"""
self._setup_task_routes()
self._setup_file_routes()
self._setup_service_routes()
...@@ -48,18 +50,15 @@ class ApiServer:
self.app.include_router(self.service_router)
def _write_file_sync(self, file_path: Path, content: bytes) -> None:
"""Synchronously write the given content to the target file path"""
with open(file_path, "wb") as buffer:
buffer.write(content)
def _stream_file_response(self, file_path: Path, filename: str | None = None) -> StreamingResponse:
"""Common file streaming response method"""
assert self.file_service is not None, "File service is not initialized"
try:
resolved_path = file_path.resolve()
# Security check: ensure file is within allowed directory
if not str(resolved_path).startswith(str(self.file_service.output_video_dir.resolve())):
raise HTTPException(status_code=403, detail="Access to this file is not allowed")
...@@ -103,24 +102,25 @@
async def create_task(message: TaskRequest): async def create_task(message: TaskRequest):
"""Create video generation task""" """Create video generation task"""
try: try:
task_id = ServiceStatus.start_task(message) if hasattr(message, "image_path") and message.image_path and message.image_path.startswith("http"):
if not await self._validate_image_url(message.image_path):
# Use background thread to handle long-running tasks raise HTTPException(status_code=400, detail=f"Image URL is not accessible: {message.image_path}")
self.stop_generation_event.clear()
self.thread = threading.Thread( task_id = task_manager.create_task(message)
target=self._process_video_generation, message.task_id = task_id
args=(message, self.stop_generation_event),
daemon=True, self._ensure_processing_thread_running()
)
self.thread.start()
return TaskResponse( return TaskResponse(
task_id=task_id, task_id=task_id,
task_status="processing", task_status="pending",
save_video_path=message.save_video_path, save_video_path=message.save_video_path,
) )
except RuntimeError as e: except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=503, detail=str(e))
except Exception as e:
logger.error(f"Failed to create task: {e}")
raise HTTPException(status_code=500, detail=str(e))
@self.tasks_router.post("/form", response_model=TaskResponse)
async def create_task_form(
...@@ -136,11 +136,9 @@
audio_file: Optional[UploadFile] = File(default=None),
video_duration: int = Form(default=5),
):
"""Create video generation task via form"""
assert self.file_service is not None, "File service is not initialized"
async def save_file_async(file: UploadFile, target_dir: Path) -> str:
"""Asynchronously save an uploaded file to the target directory"""
if not file or not file.filename:
return ""
...@@ -177,44 +175,58 @@
) )
try: try:
task_id = ServiceStatus.start_task(message) task_id = task_manager.create_task(message)
self.stop_generation_event.clear() message.task_id = task_id
self.thread = threading.Thread(
target=self._process_video_generation, self._ensure_processing_thread_running()
args=(message, self.stop_generation_event),
daemon=True,
)
self.thread.start()
return TaskResponse( return TaskResponse(
task_id=task_id, task_id=task_id,
task_status="processing", task_status="pending",
save_video_path=message.save_video_path, save_video_path=message.save_video_path,
) )
except RuntimeError as e: except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=503, detail=str(e))
except Exception as e:
logger.error(f"Failed to create form task: {e}")
raise HTTPException(status_code=500, detail=str(e))
@self.tasks_router.get("/", response_model=dict)
async def list_tasks():
"""Get all task list"""
return ServiceStatus.get_all_tasks()
return task_manager.get_all_tasks()
@self.tasks_router.get("/queue/status", response_model=dict)
async def get_queue_status():
service_status = task_manager.get_service_status()
return {
"is_processing": task_manager.is_processing(),
"current_task": service_status.get("current_task"),
"pending_count": task_manager.get_pending_task_count(),
"active_count": task_manager.get_active_task_count(),
"queue_size": self.max_queue_size,
"queue_available": self.max_queue_size - task_manager.get_active_task_count(),
}
@self.tasks_router.get("/{task_id}/status")
async def get_task_status(task_id: str):
"""Get status of specified task"""
return ServiceStatus.get_status_task_id(task_id)
status = task_manager.get_task_status(task_id)
if not status:
raise HTTPException(status_code=404, detail="Task not found")
return status
@self.tasks_router.get("/{task_id}/result")
async def get_task_result(task_id: str):
"""Get result video file of specified task"""
assert self.video_service is not None, "Video service is not initialized"
assert self.file_service is not None, "File service is not initialized"
try:
task_status = ServiceStatus.get_status_task_id(task_id) task_status = task_manager.get_task_status(task_id)
if not task_status:
raise HTTPException(status_code=404, detail="Task not found")
if not task_status or task_status.get("status") != "completed": if task_status.get("status") != TaskStatus.COMPLETED.value:
raise HTTPException(status_code=404, detail="Task not completed or does not exist") raise HTTPException(status_code=404, detail="Task not completed")
save_video_path = task_status.get("save_video_path") save_video_path = task_status.get("save_video_path")
if not save_video_path: if not save_video_path:
...@@ -232,38 +244,37 @@ class ApiServer: ...@@ -232,38 +244,37 @@ class ApiServer:
logger.error(f"Error occurred while getting task result: {e}") logger.error(f"Error occurred while getting task result: {e}")
raise HTTPException(status_code=500, detail="Failed to get task result") raise HTTPException(status_code=500, detail="Failed to get task result")
@self.tasks_router.delete("/running", response_model=StopTaskResponse) @self.tasks_router.delete("/{task_id}", response_model=StopTaskResponse)
async def stop_running_task(): async def stop_task(task_id: str):
"""Stop currently running task"""
if self.thread and self.thread.is_alive():
try: try:
logger.info("Sending stop signal to running task thread...") if task_manager.cancel_task(task_id):
self.stop_generation_event.set()
self.thread.join(timeout=5)
if self.thread.is_alive():
logger.warning("Task thread did not stop within the specified time, manual intervention may be required.")
return StopTaskResponse(
stop_status="warning",
reason="Task thread did not stop within the specified time, manual intervention may be required.",
)
else:
self.thread = None
ServiceStatus.clean_stopped_task()
gc.collect() gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
logger.info("Task stopped successfully.") logger.info(f"Task {task_id} stopped successfully.")
return StopTaskResponse(stop_status="success", reason="Task stopped successfully.") return StopTaskResponse(stop_status="success", reason="Task stopped successfully.")
else:
return StopTaskResponse(stop_status="do_nothing", reason="Task not found or already completed.")
except Exception as e: except Exception as e:
logger.error(f"Error occurred while stopping task: {str(e)}") logger.error(f"Error occurred while stopping task {task_id}: {str(e)}")
return StopTaskResponse(stop_status="error", reason=str(e))
@self.tasks_router.delete("/all/running", response_model=StopTaskResponse)
async def stop_all_running_tasks():
try:
task_manager.cancel_all_tasks()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("All tasks stopped successfully.")
return StopTaskResponse(stop_status="success", reason="All tasks stopped successfully.")
except Exception as e:
logger.error(f"Error occurred while stopping all tasks: {str(e)}")
return StopTaskResponse(stop_status="error", reason=str(e)) return StopTaskResponse(stop_status="error", reason=str(e))
else:
return StopTaskResponse(stop_status="do_nothing", reason="No running task found.")
def _setup_file_routes(self):
@self.files_router.get("/download/{file_path:path}")
async def download_file(file_path: str):
"""Download file"""
assert self.file_service is not None, "File service is not initialized"
try:
...@@ -276,36 +287,111 @@
raise HTTPException(status_code=500, detail="File download failed")
def _setup_service_routes(self):
@self.service_router.get("/status", response_model=ServiceStatusResponse)
@self.service_router.get("/status", response_model=dict)
async def get_service_status():
"""Get service status"""
return ServiceStatus.get_status_service()
return task_manager.get_service_status()
@self.service_router.get("/metadata", response_model=dict)
async def get_service_metadata():
"""Get service metadata"""
assert self.inference_service is not None, "Inference service is not initialized"
return self.inference_service.server_metadata()
def _process_video_generation(self, message: TaskRequest, stop_event: threading.Event): async def _validate_image_url(self, image_url: str) -> bool:
if not image_url or not image_url.startswith("http"):
return True
try:
parsed_url = urlparse(image_url)
if not parsed_url.scheme or not parsed_url.netloc:
return False
timeout = httpx.Timeout(connect=5.0, read=5.0)
async with httpx.AsyncClient(verify=False, timeout=timeout) as client:
response = await client.head(image_url, follow_redirects=True)
return response.status_code < 400
except Exception as e:
logger.warning(f"URL validation failed for {image_url}: {str(e)}")
return False
def _ensure_processing_thread_running(self):
"""Ensure the processing thread is running."""
if self.processing_thread is None or not self.processing_thread.is_alive():
self.stop_processing.clear()
self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True)
self.processing_thread.start()
logger.info("Started task processing thread")
def _task_processing_loop(self):
"""Main loop that processes tasks from the queue one by one."""
logger.info("Task processing loop started")
while not self.stop_processing.is_set():
task_id = task_manager.get_next_pending_task()
if task_id is None:
time.sleep(1)
continue
task_info = task_manager.get_task(task_id)
if task_info and task_info.status == TaskStatus.PENDING:
logger.info(f"Processing task {task_id}")
self._process_single_task(task_info)
logger.info("Task processing loop stopped")
def _process_single_task(self, task_info: Any):
"""Process a single task."""
assert self.video_service is not None, "Video service is not initialized" assert self.video_service is not None, "Video service is not initialized"
task_id = task_info.task_id
message = task_info.message
lock_acquired = task_manager.acquire_processing_lock(task_id, timeout=1)
if not lock_acquired:
logger.error(f"Task {task_id} failed to acquire processing lock")
task_manager.fail_task(task_id, "Failed to acquire processing lock")
return
try: try:
if stop_event.is_set(): task_manager.start_task(task_id)
logger.info(f"Task {message.task_id} received stop signal, terminating")
ServiceStatus.record_failed_task(message, error="Task stopped") if task_info.stop_event.is_set():
logger.info(f"Task {task_id} cancelled before processing")
task_manager.fail_task(task_id, "Task cancelled")
return return
# Use video generation service to process task result = asyncio.run(self.video_service.generate_video_with_stop_event(message, task_info.stop_event))
result = asyncio.run(self.video_service.generate_video(message))
if result:
task_manager.complete_task(task_id, result.save_video_path)
logger.info(f"Task {task_id} completed successfully")
else:
if task_info.stop_event.is_set():
task_manager.fail_task(task_id, "Task cancelled during processing")
logger.info(f"Task {task_id} cancelled during processing")
else:
task_manager.fail_task(task_id, "Generation failed")
logger.error(f"Task {task_id} generation failed")
except Exception as e: except Exception as e:
logger.error(f"Task {message.task_id} processing failed: {str(e)}") logger.error(f"Task {task_id} processing failed: {str(e)}")
ServiceStatus.record_failed_task(message, error=str(e)) task_manager.fail_task(task_id, str(e))
finally:
if lock_acquired:
task_manager.release_processing_lock(task_id)
def initialize_services(self, cache_dir: Path, inference_service: DistributedInferenceService):
self.file_service = FileService(cache_dir)
self.inference_service = inference_service
self.video_service = VideoGenerationService(self.file_service, inference_service)
async def cleanup(self):
self.stop_processing.set()
if self.processing_thread and self.processing_thread.is_alive():
self.processing_thread.join(timeout=5)
if self.file_service:
await self.file_service.cleanup()
def get_app(self) -> FastAPI:
return self.app
import base64
import os
import re
import uuid
from pathlib import Path
from typing import Optional, Tuple
from loguru import logger
def is_base64_audio(data: str) -> bool:
"""Check if a string is a base64-encoded audio"""
if data.startswith("data:audio/"):
return True
try:
if len(data) % 4 == 0:
base64.b64decode(data, validate=True)
decoded = base64.b64decode(data[:100])
if decoded.startswith(b"ID3"):
return True
if decoded.startswith(b"\xff\xfb") or decoded.startswith(b"\xff\xf3") or decoded.startswith(b"\xff\xf2"):
return True
if decoded.startswith(b"OggS"):
return True
if decoded.startswith(b"RIFF") and b"WAVE" in decoded[:12]:
return True
if decoded.startswith(b"fLaC"):
return True
if decoded[:4] in [b"ftyp", b"\x00\x00\x00\x20", b"\x00\x00\x00\x18"]:
return True
except Exception as e:
logger.warning(f"Error checking base64 audio: {e}")
return False
return False
def extract_base64_data(data: str) -> Tuple[str, Optional[str]]:
"""
Extract base64 data and format from a data URL or plain base64 string
Returns: (base64_data, format)
"""
if data.startswith("data:"):
match = re.match(r"data:audio/(\w+);base64,(.+)", data)
if match:
format_type = match.group(1)
base64_data = match.group(2)
return base64_data, format_type
return data, None
def save_base64_audio(base64_data: str, output_dir: str) -> str:
"""
Save base64-encoded audio to disk and return the file path
"""
Path(output_dir).mkdir(parents=True, exist_ok=True)
data, format_type = extract_base64_data(base64_data)
file_id = str(uuid.uuid4())
try:
audio_data = base64.b64decode(data)
except Exception as e:
raise ValueError(f"Invalid base64 data: {e}")
if format_type:
ext = format_type
else:
if audio_data.startswith(b"ID3") or audio_data.startswith(b"\xff\xfb") or audio_data.startswith(b"\xff\xf3") or audio_data.startswith(b"\xff\xf2"):
ext = "mp3"
elif audio_data.startswith(b"OggS"):
ext = "ogg"
elif audio_data.startswith(b"RIFF") and b"WAVE" in audio_data[:12]:
ext = "wav"
elif audio_data.startswith(b"fLaC"):
ext = "flac"
elif audio_data[:4] in [b"ftyp", b"\x00\x00\x00\x20", b"\x00\x00\x00\x18"]:
ext = "m4a"
else:
ext = "mp3"
file_path = os.path.join(output_dir, f"{file_id}.{ext}")
with open(file_path, "wb") as f:
f.write(audio_data)
return file_path
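A quick round-trip of the helpers above, assuming they are importable from the audio utility module (the module path in the comment is a guess); the payload is a synthetic RIFF/WAVE header, just enough to exercise the magic-byte sniffing:

import base64
import tempfile

# from lightx2v.server.audio_utils import is_base64_audio, save_base64_audio  # module path is an assumption

fake_wav = b"RIFF" + (36).to_bytes(4, "little") + b"WAVEfmt " + b"\x00" * 20
payload = base64.b64encode(fake_wav).decode()

assert is_base64_audio(payload)  # matched by the RIFF/WAVE branch
assert is_base64_audio("data:audio/wav;base64," + payload)  # matched by the data URL prefix

out_path = save_base64_audio("data:audio/wav;base64," + payload, tempfile.mkdtemp())
print(out_path)  # <output_dir>/<uuid>.wav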
import os
from dataclasses import dataclass
from pathlib import Path
from loguru import logger
@dataclass
class ServerConfig:
host: str = "0.0.0.0"
port: int = 8000
max_queue_size: int = 10
master_addr: str = "127.0.0.1"
master_port_range: tuple = (29500, 29600)
task_timeout: int = 300
task_history_limit: int = 1000
http_timeout: int = 30
http_max_retries: int = 3
cache_dir: str = str(Path(__file__).parent.parent / "server_cache")
max_upload_size: int = 500 * 1024 * 1024 # 500MB
@classmethod
def from_env(cls) -> "ServerConfig":
config = cls()
if env_host := os.environ.get("LIGHTX2V_HOST"):
config.host = env_host
if env_port := os.environ.get("LIGHTX2V_PORT"):
try:
config.port = int(env_port)
except ValueError:
logger.warning(f"Invalid port in environment: {env_port}")
if env_queue_size := os.environ.get("LIGHTX2V_MAX_QUEUE_SIZE"):
try:
config.max_queue_size = int(env_queue_size)
except ValueError:
logger.warning(f"Invalid max queue size: {env_queue_size}")
if env_master_addr := os.environ.get("MASTER_ADDR"):
config.master_addr = env_master_addr
if env_cache_dir := os.environ.get("LIGHTX2V_CACHE_DIR"):
config.cache_dir = env_cache_dir
return config
def find_free_master_port(self) -> str:
import socket
for port in range(self.master_port_range[0], self.master_port_range[1]):
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((self.master_addr, port))
logger.info(f"Found free port for master: {port}")
return str(port)
except OSError:
continue
raise RuntimeError(
f"No free port found for master in range {self.master_port_range[0]}-{self.master_port_range[1] - 1} "
f"on address {self.master_addr}. Please adjust 'master_port_range' or free an occupied port."
)
def validate(self) -> bool:
valid = True
if self.max_queue_size <= 0:
logger.error("max_queue_size must be positive")
valid = False
if self.task_timeout <= 0:
logger.error("task_timeout must be positive")
valid = False
return valid
server_config = ServerConfig.from_env()
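Because from_env reads overrides when the module is imported, environment variables must be set before that import. The sketch below shows the overrides handled above; the import path is an assumption.

import os

os.environ["LIGHTX2V_PORT"] = "9000"
os.environ["LIGHTX2V_MAX_QUEUE_SIZE"] = "4"

# from lightx2v.server.config import ServerConfig  # import path is an assumption

cfg = ServerConfig.from_env()
assert cfg.port == 9000 and cfg.max_queue_size == 4
assert cfg.validate()

# Pick a free rendezvous port for torch.distributed from the configured range.
master_port = cfg.find_free_master_port()
print(cfg.host, cfg.port, master_port)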
import os
import pickle
from typing import Any, Optional
import torch
import torch.distributed as dist
from loguru import logger
from .gpu_manager import gpu_manager
class DistributedManager:
def __init__(self):
self.is_initialized = False
self.rank = 0
self.world_size = 1
self.device = "cpu"
CHUNK_SIZE = 1024 * 1024
def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
try:
@@ -18,10 +25,12 @@ class DistributedManager:
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend, init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)
logger.info(f"Setup backend: {backend}")
self.device = gpu_manager.set_device_for_rank(rank, world_size)
self.is_initialized = True
self.rank = rank
@@ -46,55 +55,93 @@ class DistributedManager:
def barrier(self):
if self.is_initialized:
if torch.cuda.is_available() and dist.get_backend() == "nccl":
dist.barrier(device_ids=[torch.cuda.current_device()])
else:
dist.barrier()
def is_rank_zero(self) -> bool:
return self.rank == 0
def _broadcast_byte_chunks(self, data_bytes: bytes, device: torch.device) -> None:
total_length = len(data_bytes)
num_full_chunks = total_length // self.CHUNK_SIZE
remaining = total_length % self.CHUNK_SIZE
for i in range(num_full_chunks):
start_idx = i * self.CHUNK_SIZE
end_idx = start_idx + self.CHUNK_SIZE
chunk = data_bytes[start_idx:end_idx]
task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
dist.broadcast(task_tensor, src=0)
if remaining:
chunk = data_bytes[-remaining:]
task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
dist.broadcast(task_tensor, src=0)
def _receive_byte_chunks(self, total_length: int, device: torch.device) -> bytes:
if total_length <= 0:
return b""
received = bytearray()
remaining = total_length
while remaining > 0:
chunk_length = min(self.CHUNK_SIZE, remaining)
task_tensor = torch.empty(chunk_length, dtype=torch.uint8).to(device)
dist.broadcast(task_tensor, src=0)
received.extend(task_tensor.cpu().numpy())
remaining -= chunk_length
return bytes(received)
def broadcast_task_data(self, task_data: Optional[Any] = None) -> Optional[Any]:
if not self.is_initialized:
return None
try:
backend = dist.get_backend() if dist.is_initialized() else "gloo"
except Exception:
backend = "gloo"
if backend == "gloo":
broadcast_device = torch.device("cpu")
else:
broadcast_device = torch.device(self.device if self.device != "cpu" else "cpu")
if self.is_rank_zero():
if task_data is None:
stop_signal = torch.tensor([1], dtype=torch.int32).to(broadcast_device)
else:
stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
dist.broadcast(stop_signal, src=0)
if task_data is not None:
task_bytes = pickle.dumps(task_data)
task_length = torch.tensor([len(task_bytes)], dtype=torch.int32).to(broadcast_device)
dist.broadcast(task_length, src=0)
self._broadcast_byte_chunks(task_bytes, broadcast_device)
return task_data
else:
return None
else:
stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
dist.broadcast(stop_signal, src=0)
if stop_signal.item() == 1:
return None
else:
task_length = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
dist.broadcast(task_length, src=0)
total_length = int(task_length.item())
task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
task_data = pickle.loads(task_bytes)
return task_data
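Taken together, broadcast_task_data implements a small protocol: rank 0 first broadcasts a stop flag, then (if there is work) the pickled payload length followed by CHUNK_SIZE byte slices, while the other ranks mirror those receives. The sketch below exercises it with two CPU processes over gloo; DistributedManager is assumed importable, and the spawn launcher is illustrative only, not how the server actually starts its workers.

import torch.multiprocessing as mp

# from lightx2v.server.distributed_utils import DistributedManager  # import path is an assumption

def _worker(rank: int, world_size: int) -> None:
    manager = DistributedManager()
    manager.init_process_group(rank, world_size, "127.0.0.1", "29501")
    if manager.is_rank_zero():
        manager.broadcast_task_data({"task_id": "demo", "prompt": "a cat"})  # payload broadcast
        manager.broadcast_task_data(None)  # stop signal
    else:
        task = manager.broadcast_task_data()
        print(f"rank {rank} received: {task}")
        assert manager.broadcast_task_data() is None  # sees the stop signal
    manager.cleanup()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)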
@@ -113,7 +160,6 @@ class DistributedWorker:
self.dist_manager.cleanup()
def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
self.dist_manager.barrier()
if self.dist_manager.is_rank_zero():
...
import os
from typing import List, Optional, Tuple
import torch
from loguru import logger
class GPUManager:
def __init__(self):
self.available_gpus = self._detect_gpus()
self.gpu_count = len(self.available_gpus)
def _detect_gpus(self) -> List[int]:
if not torch.cuda.is_available():
logger.warning("No CUDA devices available, will use CPU")
return []
gpu_count = torch.cuda.device_count()
cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
if cuda_visible:
try:
visible_devices = [int(d.strip()) for d in cuda_visible.split(",")]
logger.info(f"CUDA_VISIBLE_DEVICES set to: {visible_devices}")
return list(range(len(visible_devices)))
except ValueError:
logger.warning(f"Invalid CUDA_VISIBLE_DEVICES: {cuda_visible}, using all devices")
available_gpus = list(range(gpu_count))
logger.info(f"Detected {gpu_count} GPU devices: {available_gpus}")
return available_gpus
def get_device_for_rank(self, rank: int, world_size: int) -> str:
if not self.available_gpus:
logger.info(f"Rank {rank}: Using CPU (no GPUs available)")
return "cpu"
if self.gpu_count == 1:
device = f"cuda:{self.available_gpus[0]}"
logger.info(f"Rank {rank}: Using single GPU {device}")
return device
if self.gpu_count >= world_size:
gpu_id = self.available_gpus[rank % self.gpu_count]
device = f"cuda:{gpu_id}"
logger.info(f"Rank {rank}: Assigned to dedicated GPU {device}")
return device
else:
gpu_id = self.available_gpus[rank % self.gpu_count]
device = f"cuda:{gpu_id}"
logger.info(f"Rank {rank}: Sharing GPU {device} (world_size={world_size} > gpu_count={self.gpu_count})")
return device
def set_device_for_rank(self, rank: int, world_size: int) -> str:
device = self.get_device_for_rank(rank, world_size)
if device.startswith("cuda:"):
gpu_id = int(device.split(":")[1])
torch.cuda.set_device(gpu_id)
logger.info(f"Rank {rank}: CUDA device set to {gpu_id}")
return device
def get_memory_info(self, device: Optional[str] = None) -> Tuple[int, int]:
if not torch.cuda.is_available():
return (0, 0)
if device and device.startswith("cuda:"):
gpu_id = int(device.split(":")[1])
else:
gpu_id = torch.cuda.current_device()
try:
used = torch.cuda.memory_allocated(gpu_id)
total = torch.cuda.get_device_properties(gpu_id).total_memory
return (used, total)
except Exception as e:
logger.error(f"Failed to get memory info for device {gpu_id}: {e}")
return (0, 0)
def clear_cache(self, device: Optional[str] = None):
if not torch.cuda.is_available():
return
if device and device.startswith("cuda:"):
gpu_id = int(device.split(":")[1])
with torch.cuda.device(gpu_id):
torch.cuda.empty_cache()
torch.cuda.synchronize()
else:
torch.cuda.empty_cache()
torch.cuda.synchronize()
logger.info(f"GPU cache cleared for device: {device or 'current'}")
@staticmethod
def get_optimal_world_size(requested_world_size: int) -> int:
if not torch.cuda.is_available():
logger.warning("No GPUs available, using single process")
return 1
gpu_count = torch.cuda.device_count()
if requested_world_size <= 0:
optimal_size = gpu_count
logger.info(f"Auto-detected world_size: {optimal_size} (based on {gpu_count} GPUs)")
elif requested_world_size > gpu_count:
logger.warning(f"Requested world_size ({requested_world_size}) exceeds GPU count ({gpu_count}). Processes will share GPUs.")
optimal_size = requested_world_size
else:
optimal_size = requested_world_size
return optimal_size
gpu_manager = GPUManager()
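The module-level singleton can be queried safely on CPU-only hosts as well (it falls back to "cpu" and zeroed memory counters). A short usage sketch, with the import path assumed:

# from lightx2v.server.gpu_manager import GPUManager, gpu_manager  # import path is an assumption

world_size = GPUManager.get_optimal_world_size(0)  # 0 -> auto-detect; 1 on CPU-only hosts
device = gpu_manager.get_device_for_rank(rank=0, world_size=world_size)
used, total = gpu_manager.get_memory_info(device)
print(device, world_size, f"{used}/{total} bytes allocated")
gpu_manager.clear_cache(device)  # no-op without CUDA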