Commit e08c4f90 authored by sandy, committed by GitHub

Merge branch 'main' into audio_r2v

parents 12bfd120 6d07a72e
@@ -12,6 +12,7 @@ from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
     WanTransformerInferCausVid,
 )
 from lightx2v.utils.envs import *
+from safetensors import safe_open


 class WanCausVidModel(WanModel):
@@ -28,18 +29,22 @@ class WanCausVidModel(WanModel):
         self.transformer_infer_class = WanTransformerInferCausVid

     def _load_ckpt(self, use_bf16, skip_bf16):
-        use_bfloat16 = GET_DTYPE() == "BF16"
-        ckpt_path = os.path.join(self.model_path, "causal_model.pt")
-        if not os.path.exists(ckpt_path):
-            return super()._load_ckpt(use_bf16, skip_bf16)
-        weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
-        dtype = torch.bfloat16 if use_bfloat16 else None
-        for key, value in weight_dict.items():
-            weight_dict[key] = value.to(device=self.device, dtype=dtype)
-        return weight_dict
+        ckpt_folder = "causvid_models"
+        safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.safetensors")
+        if os.path.exists(safetensors_path):
+            with safe_open(safetensors_path, framework="pt") as f:
+                weight_dict = {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
+            return weight_dict
+        ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.pt")
+        if os.path.exists(ckpt_path):
+            weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+            weight_dict = {
+                key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
+            }
+            return weight_dict
+        return super()._load_ckpt(use_bf16, skip_bf16)

     @torch.no_grad()
     def infer(self, inputs, kv_start, kv_end):
...
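The selective-precision cast above is repeated in both the safetensors and the .pt branch; the following minimal sketch (not part of the commit, with a hypothetical skip_bf16 set) isolates the same decision rule:

import torch

def maybe_cast_bf16(name: str, tensor: torch.Tensor, use_bf16: bool, skip_bf16: set) -> torch.Tensor:
    # Cast to bfloat16 when BF16 mode is on, or when the parameter name matches
    # none of the "keep original precision" patterns (norm, embedding, ...).
    if use_bf16 or all(pattern not in name for pattern in skip_bf16):
        return tensor.to(torch.bfloat16)
    return tensor

# Example: with use_bf16=False, "blocks.0.norm1.weight" keeps its original dtype.
w = torch.randn(4, dtype=torch.float32)
print(maybe_cast_bf16("blocks.0.norm1.weight", w, use_bf16=False, skip_bf16={"norm", "embedding"}).dtype)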
@@ -64,7 +64,7 @@ class WanPreInfer:
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.enable_dynamic_cfg:
             s = torch.tensor([self.cfg_scale], dtype=torch.float32).to(x.device)
-            cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(0.0, 8.0), target_range=1000.0, dtype=torch.float32).type_as(x)
+            cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(1.0, 8.0), target_range=1000.0, dtype=torch.float32).type_as(x)
             cfg_embed = weights.cfg_cond_proj.apply(cfg_embed)
             embed = embed + cfg_embed
         if GET_DTYPE() != "BF16":
...
@@ -29,6 +29,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.mask_map = None
         if self.config["cpu_offload"]:
+            if torch.cuda.get_device_capability(0) == (9, 0):
+                assert self.config["self_attn_1_type"] != "sage_attn2"
             if "offload_ratio" in self.config:
                 offload_ratio = self.config["offload_ratio"]
             else:
@@ -104,7 +106,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

     def _infer_with_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-        self.weights_stream_mgr.prefetch_weights_from_disk(weights)
+        self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
         for block_idx in range(self.blocks_num):
             if block_idx == 0:
@@ -132,7 +134,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             if block_idx == self.blocks_num - 1:
                 self.weights_stream_mgr.pin_memory_buffer.pop_front()
-            self.weights_stream_mgr._async_prefetch_block(weights)
+            self.weights_stream_mgr._async_prefetch_block(weights.blocks)
         if self.clean_cuda_cache:
             del grid_sizes, embed, embed0, seq_lens, freqs, context
@@ -189,7 +191,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

     def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
-        self.weights_stream_mgr.prefetch_weights_from_disk(weights)
+        self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
         for block_idx in range(weights.blocks_num):
             for phase_idx in range(self.weights_stream_mgr.phases_num):
@@ -236,7 +238,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                 self.weights_stream_mgr.swap_phases()
-            self.weights_stream_mgr._async_prefetch_block(weights)
+            self.weights_stream_mgr._async_prefetch_block(weights.blocks)
         if self.clean_cuda_cache:
             del attn_out, y_out, y
...
 import os
-import sys
 import torch
 import glob
 import json
@@ -37,7 +36,11 @@ class WanModel:
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
-        self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
+        if self.dit_quantized:
+            dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
+            self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", os.path.join(model_path, dit_quant_scheme))
+        else:
+            self.dit_quantized_ckpt = None
         self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
         if self.dit_quantized:
             assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
@@ -80,7 +83,12 @@ class WanModel:
         safetensors_files = glob.glob(safetensors_pattern)
         if not safetensors_files:
-            raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
+            original_pattern = os.path.join(self.model_path, "original", "*.safetensors")
+            safetensors_files = glob.glob(original_pattern)
+            if not safetensors_files:
+                raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
         weight_dict = {}
         for file_path in safetensors_files:
             file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
@@ -138,7 +146,14 @@ class WanModel:
     def _init_weights(self, weight_dict=None):
         use_bf16 = GET_DTYPE() == "BF16"
         # Some layers run with float32 to achieve high accuracy
-        skip_bf16 = {"norm", "embedding", "modulation", "time", "img_emb.proj.0", "img_emb.proj.4"}
+        skip_bf16 = {
+            "norm",
+            "embedding",
+            "modulation",
+            "time",
+            "img_emb.proj.0",
+            "img_emb.proj.4",
+        }
         if weight_dict is None:
             if not self.dit_quantized or self.weight_auto_quant:
                 self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16)
...
@@ -24,6 +24,11 @@ class WanTransformerWeights(WeightModule):
         self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
         self.add_module("blocks", self.blocks)

+    def clear(self):
+        for block in self.blocks:
+            for phase in block.compute_phases:
+                phase.clear()
+

 class WanTransformerAttentionBlock(WeightModule):
     def __init__(self, block_index, task, mm_type, config):
...
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple, Optional, Union, List, Protocol
from lightx2v.utils.utils import save_videos_grid
class TransformerModel(Protocol):
"""Protocol for transformer models"""
def set_scheduler(self, scheduler: Any) -> None: ...
def scheduler(self) -> Any: ...
class TextEncoderModel(Protocol):
"""Protocol for text encoder models"""
def infer(self, texts: List[str], config: Dict[str, Any]) -> Any: ...
class ImageEncoderModel(Protocol):
"""Protocol for image encoder models"""
def encode(self, image: Any) -> Any: ...
class VAEModel(Protocol):
"""Protocol for VAE models"""
def encode(self, image: Any) -> Tuple[Any, Dict[str, Any]]: ...
def decode(self, latents: Any, generator: Optional[Any] = None, config: Optional[Dict[str, Any]] = None) -> Any: ...
class BaseRunner(ABC):
"""Abstract base class for all Runners
Defines interface methods that all subclasses must implement
"""
def __init__(self, config: Dict[str, Any]):
self.config = config
@abstractmethod
def load_transformer(self) -> TransformerModel:
"""Load transformer model
Returns:
Loaded transformer model instance
"""
pass
@abstractmethod
def load_text_encoder(self) -> Union[TextEncoderModel, List[TextEncoderModel]]:
"""Load text encoder
Returns:
Text encoder instance or list of text encoder instances
"""
pass
@abstractmethod
def load_image_encoder(self) -> Optional[ImageEncoderModel]:
"""Load image encoder
Returns:
Image encoder instance or None if not needed
"""
pass
@abstractmethod
def load_vae(self) -> Tuple[VAEModel, VAEModel]:
"""Load VAE encoder and decoder
Returns:
Tuple[vae_encoder, vae_decoder]: VAE encoder and decoder instances
"""
pass
@abstractmethod
def run_image_encoder(self, img: Any) -> Any:
"""Run image encoder
Args:
img: Input image
Returns:
Image encoding result
"""
pass
@abstractmethod
def run_vae_encoder(self, img: Any) -> Tuple[Any, Dict[str, Any]]:
"""Run VAE encoder
Args:
img: Input image
Returns:
Tuple of VAE encoding result and additional parameters
"""
pass
@abstractmethod
def run_text_encoder(self, prompt: str, img: Optional[Any] = None) -> Any:
"""Run text encoder
Args:
prompt: Input text prompt
img: Optional input image (for some models)
Returns:
Text encoding result
"""
pass
@abstractmethod
def get_encoder_output_i2v(self, clip_encoder_out: Any, vae_encode_out: Any, text_encoder_output: Any, img: Any) -> Dict[str, Any]:
"""Combine encoder outputs for i2v task
Args:
clip_encoder_out: CLIP encoder output
vae_encode_out: VAE encoder output
text_encoder_output: Text encoder output
img: Original image
Returns:
Combined encoder output dictionary
"""
pass
@abstractmethod
def init_scheduler(self) -> None:
"""Initialize scheduler"""
pass
def set_target_shape(self) -> Dict[str, Any]:
"""Set target shape
Subclasses can override this method to provide specific implementation
Returns:
Dictionary containing target shape information
"""
return {}
def save_video_func(self, images: Any) -> None:
"""Save video implementation
Subclasses can override this method to customize save logic
Args:
images: Image sequence to save
"""
save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))
def load_vae_decoder(self) -> VAEModel:
"""Load VAE decoder
Default implementation: get decoder from load_vae method
Subclasses can override this method to provide different loading logic
Returns:
VAE decoder instance
"""
if not hasattr(self, "vae_decoder") or self.vae_decoder is None:
_, self.vae_decoder = self.load_vae()
return self.vae_decoder
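A minimal sketch of a concrete runner built on the new BaseRunner interface; the class name, stub bodies, and the module path lightx2v.models.runners.base_runner (inferred from the relative import below) are illustrative assumptions, not part of the commit:

from lightx2v.models.runners.base_runner import BaseRunner

class ToyRunner(BaseRunner):
    # Every abstract method must be overridden; stubs are enough to instantiate.
    def load_transformer(self): ...
    def load_text_encoder(self): ...
    def load_image_encoder(self): return None
    def load_vae(self): ...
    def run_image_encoder(self, img): ...
    def run_vae_encoder(self, img): ...
    def run_text_encoder(self, prompt, img=None): ...
    def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img): return {}
    def init_scheduler(self): pass

runner = ToyRunner({"save_video_path": "./output.mp4", "fps": 8})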
-import asyncio
 import gc
-import aiohttp
 import requests
 from requests.exceptions import RequestException
 import torch
@@ -13,12 +11,14 @@ from lightx2v.utils.generate_task_id import generate_task_id
 from lightx2v.utils.envs import *
 from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
 from loguru import logger
+from .base_runner import BaseRunner


-class DefaultRunner:
+class DefaultRunner(BaseRunner):
     def __init__(self, config):
-        self.config = config
+        super().__init__(config)
         self.has_prompt_enhancer = False
+        self.progress_callback = None
         if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
             self.has_prompt_enhancer = True
             if not self.check_sub_servers("prompt_enhancer"):
@@ -30,33 +30,14 @@ class DefaultRunner:
     def init_modules(self):
         logger.info("Initializing runner modules...")
-        if self.config["mode"] == "split_server":
-            self.tensor_transporter = TensorTransporter()
-            self.image_transporter = ImageTransporter()
-            if not self.check_sub_servers("dit"):
-                raise ValueError("No dit server available")
-            if not self.check_sub_servers("text_encoders"):
-                raise ValueError("No text encoder server available")
-            if self.config["task"] == "i2v":
-                if not self.check_sub_servers("image_encoder"):
-                    raise ValueError("No image encoder server available")
-                if not self.check_sub_servers("vae_model"):
-                    raise ValueError("No vae server available")
-            self.run_dit = self.run_dit_server
-            self.run_vae_decoder = self.run_vae_decoder_server
-            if self.config["task"] == "i2v":
-                self.run_input_encoder = self.run_input_encoder_server_i2v
-            else:
-                self.run_input_encoder = self.run_input_encoder_server_t2v
-        else:
-            if not self.config.get("lazy_load", False):
-                self.load_model()
-            self.run_dit = self.run_dit_local
-            self.run_vae_decoder = self.run_vae_decoder_local
-            if self.config["task"] == "i2v":
-                self.run_input_encoder = self.run_input_encoder_local_i2v
-            else:
-                self.run_input_encoder = self.run_input_encoder_local_t2v
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            self.load_model()
+        self.run_dit = self._run_dit_local
+        self.run_vae_decoder = self._run_vae_decoder_local
+        if self.config["task"] == "i2v":
+            self.run_input_encoder = self._run_input_encoder_local_i2v
+        else:
+            self.run_input_encoder = self._run_input_encoder_local_t2v

     def set_init_device(self):
         if self.config["parallel_attn_type"]:
@@ -110,9 +91,13 @@ class DefaultRunner:
         # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
         # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))

+    def set_progress_callback(self, callback):
+        self.progress_callback = callback
+
     def run(self):
-        for step_index in range(self.model.scheduler.infer_steps):
-            logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
+        total_steps = self.model.scheduler.infer_steps
+        for step_index in range(total_steps):
+            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

             with ProfilingContext4Debug("step_pre"):
                 self.model.scheduler.step_pre(step_index=step_index)
@@ -123,11 +108,14 @@ class DefaultRunner:
             with ProfilingContext4Debug("step_post"):
                 self.model.scheduler.step_post()

+            if self.progress_callback:
+                self.progress_callback(step_index + 1, total_steps)
+
         return self.model.scheduler.latents, self.model.scheduler.generator
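A hedged usage sketch for the new progress hook; the callback body and the percentage formatting are illustrative only:

# Illustrative only: report denoising progress after every scheduler step.
def on_progress(done: int, total: int) -> None:
    print(f"denoising {done}/{total} ({100.0 * done / total:.0f}%)")

runner.set_progress_callback(on_progress)  # `runner` is any DefaultRunner (subclass) instance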
-    async def run_step(self, step_index=0):
+    def run_step(self, step_index=0):
         self.init_scheduler()
-        await self.run_input_encoder()
+        self.inputs = self.run_input_encoder()
         self.model.scheduler.prepare(self.inputs["image_encoder_output"])
         self.model.scheduler.step_pre(step_index=step_index)
         self.model.infer(self.inputs)
@@ -136,14 +124,19 @@ class DefaultRunner:
     def end_run(self):
         self.model.scheduler.clear()
         del self.inputs, self.model.scheduler
-        if self.config.get("lazy_load", False):
-            self.model.transformer_infer.weights_stream_mgr.clear()
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
+                self.model.transformer_infer.weights_stream_mgr.clear()
+            if hasattr(self.model.transformer_weights, "clear"):
+                self.model.transformer_weights.clear()
+            self.model.pre_weight.clear()
+            self.model.post_weight.clear()
             del self.model
         torch.cuda.empty_cache()
         gc.collect()
     @ProfilingContext("Run Encoders")
-    async def run_input_encoder_local_i2v(self):
+    def _run_input_encoder_local_i2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         img = Image.open(self.config["image_path"]).convert("RGB")
         clip_encoder_out = self.run_image_encoder(img)
@@ -154,16 +147,19 @@ class DefaultRunner:
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

     @ProfilingContext("Run Encoders")
-    async def run_input_encoder_local_t2v(self):
+    def _run_input_encoder_local_t2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         text_encoder_output = self.run_text_encoder(prompt, None)
         torch.cuda.empty_cache()
         gc.collect()
-        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
+        return {
+            "text_encoder_output": text_encoder_output,
+            "image_encoder_output": None,
+        }
     @ProfilingContext("Run DiT")
-    async def run_dit_local(self, kwargs):
-        if self.config.get("lazy_load", False):
+    def _run_dit_local(self, kwargs):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()
         self.init_scheduler()
         self.model.scheduler.prepare(self.inputs["image_encoder_output"])
@@ -172,11 +168,11 @@ class DefaultRunner:
         return latents, generator
     @ProfilingContext("Run VAE Decoder")
-    async def run_vae_decoder_local(self, latents, generator):
-        if self.config.get("lazy_load", False):
+    def _run_vae_decoder_local(self, latents, generator):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae_decoder()
         images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_decoder
             torch.cuda.empty_cache()
             gc.collect()
@@ -187,115 +183,39 @@ class DefaultRunner:
         if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
             self.save_video_func(images)
async def post_task(self, task_type, urls, message, device="cuda"):
while True:
for url in urls:
async with aiohttp.ClientSession() as session:
async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response:
status = await response.json()
if status["service_status"] == "idle":
async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response:
result = await response.json()
if result["kwargs"] is not None:
for k, v in result["kwargs"].items():
setattr(self.config, k, v)
return self.tensor_transporter.load_tensor(result["output"], device)
await asyncio.sleep(0.1)
     def post_prompt_enhancer(self):
         while True:
             for url in self.config["sub_servers"]["prompt_enhancer"]:
                 response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                 if response["service_status"] == "idle":
-                    response = requests.post(f"{url}/v1/local/prompt_enhancer/generate", json={"task_id": generate_task_id(), "prompt": self.config["prompt"]})
+                    response = requests.post(
+                        f"{url}/v1/local/prompt_enhancer/generate",
+                        json={
+                            "task_id": generate_task_id(),
+                            "prompt": self.config["prompt"],
+                        },
+                    )
                     enhanced_prompt = response.json()["output"]
                     logger.info(f"Enhanced prompt: {enhanced_prompt}")
                     return enhanced_prompt

-    async def post_encoders_i2v(self, prompt, img=None, n_prompt=None, i2v=False):
tasks = []
img_byte = self.image_transporter.prepare_image(img)
tasks.append(
asyncio.create_task(self.post_task(task_type="image_encoder", urls=self.config["sub_servers"]["image_encoder"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
)
tasks.append(
asyncio.create_task(self.post_task(task_type="vae_model/encoder", urls=self.config["sub_servers"]["vae_model"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
)
tasks.append(
asyncio.create_task(
self.post_task(
task_type="text_encoders",
urls=self.config["sub_servers"]["text_encoders"],
message={"task_id": generate_task_id(), "text": prompt, "img": img_byte, "n_prompt": n_prompt},
device="cuda",
)
)
)
results = await asyncio.gather(*tasks)
# clip_encoder, vae_encoder, text_encoders
return results[0], results[1], results[2]
async def post_encoders_t2v(self, prompt, n_prompt=None):
tasks = []
tasks.append(
asyncio.create_task(
self.post_task(
task_type="text_encoders",
urls=self.config["sub_servers"]["text_encoders"],
message={"task_id": generate_task_id(), "text": prompt, "img": None, "n_prompt": n_prompt},
device="cuda",
)
)
)
results = await asyncio.gather(*tasks)
# text_encoders
return results[0]
async def run_input_encoder_server_i2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
n_prompt = self.config.get("negative_prompt", "")
img = Image.open(self.config["image_path"]).convert("RGB")
clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders_i2v(prompt, img, n_prompt)
torch.cuda.empty_cache()
gc.collect()
return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
async def run_input_encoder_server_t2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
n_prompt = self.config.get("negative_prompt", "")
text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
torch.cuda.empty_cache()
gc.collect()
return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
async def run_dit_server(self, kwargs):
if self.inputs.get("image_encoder_output", None) is not None:
self.inputs["image_encoder_output"].pop("img", None)
dit_output = await self.post_task(
task_type="dit",
urls=self.config["sub_servers"]["dit"],
message={"task_id": generate_task_id(), "inputs": self.tensor_transporter.prepare_tensor(self.inputs), "kwargs": self.tensor_transporter.prepare_tensor(kwargs)},
device="cuda",
)
return dit_output, None
async def run_vae_decoder_server(self, latents, generator):
images = await self.post_task(
task_type="vae_model/decoder",
urls=self.config["sub_servers"]["vae_model"],
message={"task_id": generate_task_id(), "latents": self.tensor_transporter.prepare_tensor(latents)},
device="cpu",
)
return images
-    async def run_pipeline(self):
+    def run_pipeline(self, save_video=True):
         if self.config["use_prompt_enhancer"]:
             self.config["prompt_enhanced"] = self.post_prompt_enhancer()
-        self.inputs = await self.run_input_encoder()
+        self.inputs = self.run_input_encoder()
         kwargs = self.set_target_shape()
-        latents, generator = await self.run_dit(kwargs)
-        images = await self.run_vae_decoder(latents, generator)
-        self.save_video(images)
-        del latents, generator, images
+        latents, generator = self.run_dit(kwargs)
+        images = self.run_vae_decoder(latents, generator)
+        if save_video:
+            self.save_video(images)
+        del latents, generator
         torch.cuda.empty_cache()
         gc.collect()
+        return images
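With run_pipeline now synchronous and returning the decoded frames, a minimal call site might look like the sketch below; it is illustrative only and assumes a fully initialized runner:

# Illustrative only: skip file output and keep the decoded frames in memory.
images = runner.run_pipeline(save_video=False)
print(type(images))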
@@ -329,12 +329,15 @@ class WanAudioRunner(WanRunner):
     def load_transformer(self):
         base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
-        if self.config.lora_path:
+        if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(base_model)
-            lora_name = lora_wrapper.load_lora(self.config.lora_path)
-            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-            logger.info(f"Loaded LoRA: {lora_name}")
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
+                lora_name = lora_wrapper.load_lora(lora_path)
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         return base_model
...
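The same lora_path to lora_configs migration recurs in the runners below; a hedged sketch of the new config shape, with illustrative file names and strengths:

# Illustrative config fragment: "lora_configs" replaces "lora_path"/"strength_model"
# and accepts several adapters, each with its own strength (default 1.0).
config["lora_configs"] = [
    {"path": "/path/to/style_lora.safetensors", "strength": 1.0},
    {"path": "/path/to/motion_lora.safetensors", "strength": 0.6},
]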
@@ -24,24 +24,26 @@ import torch.distributed as dist
 class WanCausVidRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
-        self.num_frame_per_block = self.model.config.num_frame_per_block
-        self.num_frames = self.model.config.num_frames
-        self.frame_seq_length = self.model.config.frame_seq_length
-        self.infer_blocks = self.model.config.num_blocks
-        self.num_fragments = self.model.config.num_fragments
+        self.num_frame_per_block = self.config.num_frame_per_block
+        self.num_frames = self.config.num_frames
+        self.frame_seq_length = self.config.frame_seq_length
+        self.infer_blocks = self.config.num_blocks
+        self.num_fragments = self.config.num_fragments

     def load_transformer(self):
-        if self.config.lora_path:
+        if self.config.get("lora_configs") and self.config.lora_configs:
             model = WanModel(
                 self.config.model_path,
                 self.config,
                 self.init_device,
             )
             lora_wrapper = WanLoraWrapper(model)
-            for lora_path in self.config.lora_path:
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
-                lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-                logger.info(f"Loaded LoRA: {lora_name}")
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         else:
             model = WanCausVidModel(self.config.model_path, self.config, self.init_device)
         return model
...
@@ -24,17 +24,19 @@ class WanDistillRunner(WanRunner):
         super().__init__(config)

     def load_transformer(self):
-        if self.config.lora_path:
+        if self.config.get("lora_configs") and self.config.lora_configs:
             model = WanModel(
                 self.config.model_path,
                 self.config,
                 self.init_device,
             )
             lora_wrapper = WanLoraWrapper(model)
-            for lora_path in self.config.lora_path:
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
-                lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-                logger.info(f"Loaded LoRA: {lora_name}")
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         else:
             model = WanDistillModel(self.config.model_path, self.config, self.init_device)
         return model
...
@@ -7,6 +7,9 @@ from PIL import Image
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
+from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
+    WanScheduler4ChangingResolution,
+)
 from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
     WanSchedulerTeaCaching,
     WanSchedulerTaylorCaching,
@@ -35,18 +38,36 @@ class WanRunner(DefaultRunner):
             self.config,
             self.init_device,
         )
-        if self.config.lora_path:
+        if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(model)
-            for lora_path in self.config.lora_path:
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
-                lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-                logger.info(f"Loaded LoRA: {lora_name}")
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         return model

     def load_image_encoder(self):
         image_encoder = None
         if self.config.task == "i2v":
# quant_config
clip_quantized = self.config.get("clip_quantized", False)
if clip_quantized:
clip_quant_scheme = self.config.get("clip_quant_scheme", None)
assert clip_quant_scheme is not None
clip_quantized_ckpt = self.config.get(
"clip_quantized_ckpt",
os.path.join(
os.path.join(self.config.model_path, clip_quant_scheme),
f"clip-{clip_quant_scheme}.pth",
),
)
else:
clip_quantized_ckpt = None
clip_quant_scheme = None
             image_encoder = CLIPModel(
                 dtype=torch.float16,
                 device=self.init_device,
@@ -54,25 +75,48 @@ class WanRunner(DefaultRunner):
                     self.config.model_path,
                     "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
                 ),
-                clip_quantized=self.config.get("clip_quantized", False),
-                clip_quantized_ckpt=self.config.get("clip_quantized_ckpt", None),
-                quant_scheme=self.config.get("clip_quant_scheme", None),
+                clip_quantized=clip_quantized,
+                clip_quantized_ckpt=clip_quantized_ckpt,
+                quant_scheme=clip_quant_scheme,
             )
         return image_encoder

     def load_text_encoder(self):
# offload config
t5_offload = self.config.get("t5_cpu_offload", False)
if t5_offload:
t5_device = torch.device("cpu")
else:
t5_device = torch.device("cuda")
# quant_config
t5_quantized = self.config.get("t5_quantized", False)
if t5_quantized:
t5_quant_scheme = self.config.get("t5_quant_scheme", None)
assert t5_quant_scheme is not None
t5_quantized_ckpt = self.config.get(
"t5_quantized_ckpt",
os.path.join(
os.path.join(self.config.model_path, t5_quant_scheme),
f"models_t5_umt5-xxl-enc-{t5_quant_scheme}.pth",
),
)
else:
t5_quant_scheme = None
t5_quantized_ckpt = None
         text_encoder = T5EncoderModel(
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
-            device=self.init_device,
+            device=t5_device,
             checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
             tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
             shard_fn=None,
-            cpu_offload=self.config.cpu_offload,
+            cpu_offload=t5_offload,
             offload_granularity=self.config.get("t5_offload_granularity", "model"),
-            t5_quantized=self.config.get("t5_quantized", False),
-            t5_quantized_ckpt=self.config.get("t5_quantized_ckpt", None),
-            quant_scheme=self.config.get("t5_quant_scheme", None),
+            t5_quantized=t5_quantized,
+            t5_quantized_ckpt=t5_quantized_ckpt,
+            quant_scheme=t5_quant_scheme,
         )
         text_encoders = [text_encoder]
         return text_encoders
@@ -114,28 +158,31 @@ class WanRunner(DefaultRunner):
         return vae_encoder, vae_decoder

     def init_scheduler(self):
-        if self.config.feature_caching == "NoCaching":
-            scheduler = WanScheduler(self.config)
-        elif self.config.feature_caching == "Tea":
-            scheduler = WanSchedulerTeaCaching(self.config)
-        elif self.config.feature_caching == "TaylorSeer":
-            scheduler = WanSchedulerTaylorCaching(self.config)
-        elif self.config.feature_caching == "Ada":
-            scheduler = WanSchedulerAdaCaching(self.config)
-        elif self.config.feature_caching == "Custom":
-            scheduler = WanSchedulerCustomCaching(self.config)
+        if self.config.get("changing_resolution", False):
+            scheduler = WanScheduler4ChangingResolution(self.config)
         else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
+            if self.config.feature_caching == "NoCaching":
+                scheduler = WanScheduler(self.config)
+            elif self.config.feature_caching == "Tea":
+                scheduler = WanSchedulerTeaCaching(self.config)
+            elif self.config.feature_caching == "TaylorSeer":
+                scheduler = WanSchedulerTaylorCaching(self.config)
+            elif self.config.feature_caching == "Ada":
+                scheduler = WanSchedulerAdaCaching(self.config)
+            elif self.config.feature_caching == "Custom":
+                scheduler = WanSchedulerCustomCaching(self.config)
+            else:
+                raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         self.model.set_scheduler(scheduler)
     def run_text_encoder(self, text, img):
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.text_encoders = self.load_text_encoder()
         text_encoder_output = {}
         n_prompt = self.config.get("negative_prompt", "")
         context = self.text_encoders[0].infer([text])
         context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.text_encoders[0]
             torch.cuda.empty_cache()
             gc.collect()
@@ -144,11 +191,11 @@ class WanRunner(DefaultRunner):
         return text_encoder_output

     def run_image_encoder(self, img):
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
         clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()
             gc.collect()
@@ -179,7 +226,7 @@ class WanRunner(DefaultRunner):
         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
         msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
         msk = msk.transpose(1, 2)[0]
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_encoder = self.load_vae_encoder()
         vae_encode_out = self.vae_encoder.encode(
             [
@@ -193,7 +240,7 @@ class WanRunner(DefaultRunner):
             ],
             self.config,
         )[0]
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_encoder
             torch.cuda.empty_cache()
             gc.collect()
...
import torch
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
class WanScheduler4ChangingResolution(WanScheduler):
def __init__(self, config):
super().__init__(config)
self.resolution_rate = config.get("resolution_rate", 0.75)
self.changing_resolution_steps = config.get("changing_resolution_steps", config.infer_steps // 2)
def prepare_latents(self, target_shape, dtype=torch.float32):
self.latents = torch.randn(
target_shape[0],
target_shape[1],
int(target_shape[2] * self.resolution_rate) // 2 * 2,
int(target_shape[3] * self.resolution_rate) // 2 * 2,
dtype=dtype,
device=self.device,
generator=self.generator,
)
self.noise_original_resolution = torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
)
def step_post(self):
if self.step_index == self.changing_resolution_steps:
self.step_post_upsample()
else:
super().step_post()
def step_post_upsample(self):
# 1. denoised sample to clean noise
model_output = self.noise_pred.to(torch.float32)
sample = self.latents.to(torch.float32)
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
denoised_sample = x0_pred.to(sample.dtype)
# 2. upsample clean noise to target shape
denoised_sample_5d = denoised_sample.unsqueeze(0) # (C,T,H,W) -> (1,C,T,H,W)
clean_noise = torch.nn.functional.interpolate(denoised_sample_5d, size=(self.config.target_shape[1], self.config.target_shape[2], self.config.target_shape[3]), mode="trilinear")
clean_noise = clean_noise.squeeze(0) # (1,C,T,H,W) -> (C,T,H,W)
# 3. add noise to clean noise
noisy_sample = self.add_noise(clean_noise, self.noise_original_resolution, self.timesteps[self.step_index + 1])
# 4. update latents
self.latents = noisy_sample
# self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
# 5. update timesteps using shift + 2 for more aggressive denoising
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + 2)
def add_noise(self, original_samples, noise, timesteps):
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
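A short sketch of the latent-shape arithmetic behind the resolution switch above; the latent sizes are illustrative, not taken from the commit:

# Illustrative arithmetic: with resolution_rate=0.75, the first part of the schedule
# runs on a latent whose H and W are 75% of the target, rounded down to even values;
# step_post_upsample then trilinearly interpolates back to the full target shape.
target_h, target_w = 60, 104            # hypothetical latent height/width
rate = 0.75
low_h = int(target_h * rate) // 2 * 2   # 44
low_w = int(target_w * rate) // 2 * 2   # 78
print(low_h, low_w)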
@@ -9,9 +9,13 @@ class WanStepDistillScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
         self.denoising_step_list = config.denoising_step_list
-        self.infer_steps = self.config.infer_steps
+        self.infer_steps = len(self.denoising_step_list)
         self.sample_shift = self.config.sample_shift
+        self.num_train_timesteps = 1000
+        self.sigma_max = 1.0
+        self.sigma_min = 0.0

     def prepare(self, image_encoder_output):
         self.generator = torch.Generator(device=self.device)
         self.generator.manual_seed(self.config.seed)
@@ -23,46 +27,30 @@ class WanStepDistillScheduler(WanScheduler):
         elif self.config.task in ["i2v"]:
             self.seq_len = self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1]

-        alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
-        sigmas = 1.0 - alphas
-        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
-        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
-        self.sigmas = sigmas
-        self.timesteps = sigmas * self.num_train_timesteps
-        self.model_outputs = [None] * self.solver_order
-        self.timestep_list = [None] * self.solver_order
-        self.last_sample = None
-        self.sigmas = self.sigmas.to("cpu")
-        self.sigma_min = self.sigmas[-1].item()
-        self.sigma_max = self.sigmas[0].item()
-        if len(self.denoising_step_list) == self.infer_steps:  # use denoising_step_list when it is valid
-            self.set_denoising_timesteps(device=self.device)
-        else:
-            self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+        self.set_denoising_timesteps(device=self.device)

     def set_denoising_timesteps(self, device: Union[str, torch.device] = None):
-        self.timesteps = torch.tensor(self.denoising_step_list, device=device, dtype=torch.int64)
-        self.sigmas = torch.cat([self.timesteps / self.num_train_timesteps, torch.tensor([0.0], device=device)])
-        self.sigmas = self.sigmas.to("cpu")
-        self.infer_steps = len(self.timesteps)
-        self.model_outputs = [
-            None,
-        ] * self.solver_order
-        self.lower_order_nums = 0
-        self.last_sample = None
-        self._begin_index = None
+        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min)
+        self.sigmas = torch.linspace(sigma_start, self.sigma_min, self.num_train_timesteps + 1)[:-1]
+        self.sigmas = self.sample_shift * self.sigmas / (1 + (self.sample_shift - 1) * self.sigmas)
+        self.timesteps = self.sigmas * self.num_train_timesteps
+        self.denoising_step_index = [self.num_train_timesteps - x for x in self.denoising_step_list]
+        self.timesteps = self.timesteps[self.denoising_step_index].to(device)
+        self.sigmas = self.sigmas[self.denoising_step_index].to("cpu")

     def reset(self):
-        self.model_outputs = [None] * self.solver_order
-        self.timestep_list = [None] * self.solver_order
-        self.last_sample = None
-        self.noise_pred = None
-        self.this_order = None
-        self.lower_order_nums = 0
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
def add_noise(self, original_samples, noise, sigma):
sample = (1 - sigma) * original_samples + sigma * noise
return sample.type_as(noise)
def step_post(self):
flow_pred = self.noise_pred.to(torch.float32)
sigma = self.sigmas[self.step_index].item()
noisy_image_or_video = self.latents.to(torch.float32) - sigma * flow_pred
if self.step_index < self.infer_steps - 1:
sigma = self.sigmas[self.step_index + 1].item()
noisy_image_or_video = self.add_noise(noisy_image_or_video, torch.randn_like(noisy_image_or_video), self.sigmas[self.step_index + 1].item())
self.latents = noisy_image_or_video.to(self.latents.dtype)
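A hedged numeric sketch of how set_denoising_timesteps maps a distilled step list onto shifted sigmas; the four-step schedule and the shift value are assumptions for illustration:

# Illustrative only: reproduce the timestep selection outside the scheduler.
import torch

num_train_timesteps, sample_shift = 1000, 5.0
denoising_step_list = [1000, 750, 500, 250]  # hypothetical 4-step distillation schedule

sigmas = torch.linspace(1.0, 0.0, num_train_timesteps + 1)[:-1]
sigmas = sample_shift * sigmas / (1 + (sample_shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
index = [num_train_timesteps - x for x in denoising_step_list]  # [0, 250, 500, 750]
print(timesteps[index])  # approximately [1000.0, 937.5, 833.3, 625.0]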
@@ -868,7 +868,7 @@ class WanVAE:
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
-        if args.cpu_offload:
+        if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cuda()
         if self.use_tiling:
@@ -876,7 +876,7 @@ class WanVAE:
         else:
             out = [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
-        if args.cpu_offload:
+        if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cpu()
         return out
...
@@ -90,7 +90,6 @@ def _distributed_inference_worker(rank, world_size, master_addr, master_port, args):
     # Initialize configuration and model
     config = set_config(args)
-    config["mode"] = "server"
     logger.info(f"Rank {rank} config: {config}")

     runner = init_runner(config)
@@ -186,6 +185,12 @@ class DistributedInferenceService:
         self.is_running = False

     def start_distributed_inference(self, args) -> bool:
+        if hasattr(args, "lora_path") and args.lora_path:
+            args.lora_configs = [{"path": args.lora_path, "strength": getattr(args, "lora_strength", 1.0)}]
+            delattr(args, "lora_path")
+            if hasattr(args, "lora_strength"):
+                delattr(args, "lora_strength")
         self.args = args
         if self.is_running:
             logger.warning("Distributed inference service is already running")
...
import aiofiles
import asyncio
from PIL import Image
import io
from typing import Union
from pathlib import Path
from loguru import logger
async def load_image_async(path: Union[str, Path]) -> Image.Image:
try:
async with aiofiles.open(path, "rb") as f:
data = await f.read()
return await asyncio.to_thread(lambda: Image.open(io.BytesIO(data)).convert("RGB"))
except Exception as e:
logger.error(f"Failed to load image from {path}: {e}")
raise
async def save_video_async(video_path: Union[str, Path], video_data: bytes):
try:
video_path = Path(video_path)
video_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(video_path, "wb") as f:
await f.write(video_data)
logger.info(f"Video saved to {video_path}")
except Exception as e:
logger.error(f"Failed to save video to {video_path}: {e}")
raise
async def read_text_async(path: Union[str, Path], encoding: str = "utf-8") -> str:
try:
async with aiofiles.open(path, "r", encoding=encoding) as f:
return await f.read()
except Exception as e:
logger.error(f"Failed to read text from {path}: {e}")
raise
async def write_text_async(path: Union[str, Path], content: str, encoding: str = "utf-8"):
try:
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(path, "w", encoding=encoding) as f:
await f.write(content)
logger.info(f"Text written to {path}")
except Exception as e:
logger.error(f"Failed to write text to {path}: {e}")
raise
async def exists_async(path: Union[str, Path]) -> bool:
return await asyncio.to_thread(lambda: Path(path).exists())
async def read_bytes_async(path: Union[str, Path]) -> bytes:
try:
async with aiofiles.open(path, "rb") as f:
return await f.read()
except Exception as e:
logger.error(f"Failed to read bytes from {path}: {e}")
raise
async def write_bytes_async(path: Union[str, Path], data: bytes):
try:
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(path, "wb") as f:
await f.write(data)
logger.debug(f"Bytes written to {path}")
except Exception as e:
logger.error(f"Failed to write bytes to {path}: {e}")
raise
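A brief usage sketch for the new async file helpers; the file names are illustrative and the helpers are assumed importable from the module in which they are defined above:

# Illustrative only: load an image and write a prompt file through the async helpers.
import asyncio

async def main():
    img = await load_image_async("input.png")          # PIL.Image converted to RGB
    await write_text_async("out/prompt.txt", "a cat surfing")
    print(img.size, await exists_async("out/prompt.txt"))

asyncio.run(main())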
 import torch
-from qtorch.quant import float_quantize
 from loguru import logger

+try:
+    from qtorch.quant import float_quantize
+except Exception:
+    logger.warning("qtorch not found. Please install qtorch (pip install qtorch).")
+    float_quantize = None


 class BaseQuantizer(object):
     def __init__(self, bit, symmetric, granularity, **kwargs):
...
@@ -17,8 +17,7 @@ def get_default_config():
         "teacache_thresh": 0.26,
         "use_ret_steps": False,
         "use_bfloat16": True,
-        "lora_path": None,
-        "strength_model": 1.0,
+        "lora_configs": None,  # List of dicts with 'path' and 'strength' keys
         "mm_config": {},
         "use_prompt_enhancer": False,
     }
...
@@ -58,6 +58,14 @@ def cache_video(
     value_range=(-1, 1),
     retry=5,
 ):
+    save_dir = os.path.dirname(save_file)
+    try:
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir, exist_ok=True)
+    except Exception as e:
+        logger.error(f"Failed to create directory: {save_dir}, error: {e}")
+        return None
+
     cache_file = save_file

     # save to cache
...
@@ -94,6 +94,7 @@ set(SOURCES
     "csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu"
     "csrc/gemm/nvfp4_quant_kernels_sm120.cu"
     "csrc/gemm/mxfp8_quant_kernels_sm120.cu"
+    "csrc/gemm/mxfp6_quant_kernels_sm120.cu"
     "csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu"
     "csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu"
     "csrc/common_extension.cc"
...