Commit 7a35c418 authored by PengGao's avatar PengGao Committed by GitHub
Browse files

Refactor/async (#116)

* Add asynchronous processing capabilities with AsyncWrapper and BaseRunner class implementation

* Refactor async I/O functions for consistency and improved readability

* Add progress callback functionality to DefaultRunner and improve logging during execution

* Enhance DefaultRunner's run_pipeline method to include optional video saving functionality

* Remove unnecessary deletion of inputs in DefaultRunner's run_pipeline method to streamline resource management
parent 0f408885
import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Optional
class AsyncWrapper:
    """Async facade over a runner.

    In ``split_server`` mode the runner's ``_run_*_server`` coroutines are
    awaited directly; otherwise the blocking ``_run_*_local`` methods are
    offloaded to a private thread pool so they do not stall the event loop.
    Use as an async context manager so the pool is shut down deterministically.
    """

    def __init__(self, runner, max_workers: Optional[int] = None):
        self.runner = runner
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Wait for in-flight work so nothing is silently dropped on exit.
        if self.executor:
            self.executor.shutdown(wait=True)

    async def run_in_executor(self, func: Callable, *args, **kwargs) -> Any:
        """Run a blocking callable on the wrapper's thread pool and await it.

        Fix: ``loop.run_in_executor`` forwards *positional* arguments only,
        so keyword arguments must be bound with ``functools.partial`` —
        the previous ``**kwargs`` pass-through raised ``TypeError`` whenever
        kwargs were supplied. Also uses ``asyncio.get_running_loop()``
        instead of the deprecated-in-coroutines ``asyncio.get_event_loop()``.
        """
        loop = asyncio.get_running_loop()
        call = functools.partial(func, *args, **kwargs)
        return await loop.run_in_executor(self.executor, call)

    async def run_input_encoder(self):
        """Dispatch input encoding to the server or local implementation."""
        if self.runner.config["mode"] == "split_server":
            if self.runner.config["task"] == "i2v":
                return await self.runner._run_input_encoder_server_i2v()
            return await self.runner._run_input_encoder_server_t2v()
        if self.runner.config["task"] == "i2v":
            return await self.run_in_executor(self.runner._run_input_encoder_local_i2v)
        return await self.run_in_executor(self.runner._run_input_encoder_local_t2v)

    async def run_dit(self, kwargs):
        """Run the DiT denoising stage (server-side or local)."""
        if self.runner.config["mode"] == "split_server":
            return await self.runner._run_dit_server(kwargs)
        return await self.run_in_executor(self.runner._run_dit_local, kwargs)

    async def run_vae_decoder(self, latents, generator):
        """Decode latents into frames (server-side or local)."""
        if self.runner.config["mode"] == "split_server":
            return await self.runner._run_vae_decoder_server(latents, generator)
        return await self.run_in_executor(self.runner._run_vae_decoder_local, latents, generator)

    async def run_prompt_enhancer(self):
        """Run prompt enhancement when enabled; return None when disabled."""
        if self.runner.config["use_prompt_enhancer"]:
            return await self.run_in_executor(self.runner.post_prompt_enhancer)
        return None

    async def save_video(self, images):
        """Persist rendered frames via the runner's blocking save_video."""
        return await self.run_in_executor(self.runner.save_video, images)
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple, Optional
from lightx2v.utils.utils import save_videos_grid
class BaseRunner(ABC):
    """Abstract contract shared by every Runner implementation.

    Concrete subclasses supply model loading and encoder execution; this
    base provides default (overridable) behavior for target-shape lookup,
    video saving, and lazy VAE-decoder loading.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config

    @abstractmethod
    def load_transformer(self):
        """Load and return the transformer model instance."""
        pass

    @abstractmethod
    def load_text_encoder(self):
        """Load and return the text encoder (a single instance or a list)."""
        pass

    @abstractmethod
    def load_image_encoder(self):
        """Load and return the image encoder instance."""
        pass

    @abstractmethod
    def load_vae(self) -> Tuple[Any, Any]:
        """Load the VAE, returning an ``(encoder, decoder)`` pair."""
        pass

    @abstractmethod
    def run_image_encoder(self, img):
        """Encode ``img`` with the image encoder and return the result."""
        pass

    @abstractmethod
    def run_vae_encoder(self, img):
        """Encode ``img`` with the VAE encoder; returns the encoding plus
        any extra parameters the pipeline needs."""
        pass

    @abstractmethod
    def run_text_encoder(self, prompt: str, img: Optional[Any] = None):
        """Encode ``prompt`` (optionally conditioned on ``img`` for models
        that support it) and return the text encoding."""
        pass

    @abstractmethod
    def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
        """Merge CLIP, VAE, and text encoder outputs (plus the original
        image) into the combined encoder-output dictionary for i2v."""
        pass

    @abstractmethod
    def init_scheduler(self):
        """Initialize the sampling scheduler."""
        pass

    def set_target_shape(self) -> Dict[str, Any]:
        """Return target-shape information; the default is empty.

        Subclasses override this to report model-specific shapes.
        """
        return {}

    def save_video_func(self, images):
        """Write ``images`` out as a video grid.

        Output path and frame rate come from the config, falling back to
        ``./output.mp4`` and 8 fps. Subclasses may override to customize.
        """
        out_path = self.config.get("save_video_path", "./output.mp4")
        frame_rate = self.config.get("fps", 8)
        save_videos_grid(images, out_path, n_rows=1, fps=frame_rate)

    def load_vae_decoder(self):
        """Return the VAE decoder, loading it lazily on first use.

        The decoder is cached on ``self.vae_decoder``; the default pulls it
        from :meth:`load_vae`. Subclasses may override the loading logic.
        """
        decoder = getattr(self, "vae_decoder", None)
        if decoder is None:
            _, decoder = self.load_vae()
            self.vae_decoder = decoder
        return decoder
...@@ -13,12 +13,15 @@ from lightx2v.utils.generate_task_id import generate_task_id ...@@ -13,12 +13,15 @@ from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
from loguru import logger from loguru import logger
from .async_wrapper import AsyncWrapper
from .base_runner import BaseRunner
class DefaultRunner: class DefaultRunner(BaseRunner):
def __init__(self, config): def __init__(self, config):
self.config = config super().__init__(config)
self.has_prompt_enhancer = False self.has_prompt_enhancer = False
self.progress_callback = None
if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None: if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
self.has_prompt_enhancer = True self.has_prompt_enhancer = True
if not self.check_sub_servers("prompt_enhancer"): if not self.check_sub_servers("prompt_enhancer"):
...@@ -42,21 +45,9 @@ class DefaultRunner: ...@@ -42,21 +45,9 @@ class DefaultRunner:
raise ValueError("No image encoder server available") raise ValueError("No image encoder server available")
if not self.check_sub_servers("vae_model"): if not self.check_sub_servers("vae_model"):
raise ValueError("No vae server available") raise ValueError("No vae server available")
self.run_dit = self.run_dit_server
self.run_vae_decoder = self.run_vae_decoder_server
if self.config["task"] == "i2v":
self.run_input_encoder = self.run_input_encoder_server_i2v
else:
self.run_input_encoder = self.run_input_encoder_server_t2v
else: else:
if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False): if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
self.load_model() self.load_model()
self.run_dit = self.run_dit_local
self.run_vae_decoder = self.run_vae_decoder_local
if self.config["task"] == "i2v":
self.run_input_encoder = self.run_input_encoder_local_i2v
else:
self.run_input_encoder = self.run_input_encoder_local_t2v
def set_init_device(self): def set_init_device(self):
if self.config["parallel_attn_type"]: if self.config["parallel_attn_type"]:
...@@ -110,9 +101,13 @@ class DefaultRunner: ...@@ -110,9 +101,13 @@ class DefaultRunner:
# self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5)) # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
# self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5)) # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))
def set_progress_callback(self, callback):
self.progress_callback = callback
def run(self): def run(self):
for step_index in range(self.model.scheduler.infer_steps): total_steps = self.model.scheduler.infer_steps
logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}") for step_index in range(total_steps):
logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
with ProfilingContext4Debug("step_pre"): with ProfilingContext4Debug("step_pre"):
self.model.scheduler.step_pre(step_index=step_index) self.model.scheduler.step_pre(step_index=step_index)
...@@ -123,15 +118,19 @@ class DefaultRunner: ...@@ -123,15 +118,19 @@ class DefaultRunner:
with ProfilingContext4Debug("step_post"): with ProfilingContext4Debug("step_post"):
self.model.scheduler.step_post() self.model.scheduler.step_post()
if self.progress_callback:
self.progress_callback(step_index + 1, total_steps)
return self.model.scheduler.latents, self.model.scheduler.generator return self.model.scheduler.latents, self.model.scheduler.generator
async def run_step(self, step_index=0): async def run_step(self, step_index=0):
self.init_scheduler() async with AsyncWrapper(self) as wrapper:
await self.run_input_encoder() self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"]) self.inputs = await wrapper.run_input_encoder()
self.model.scheduler.step_pre(step_index=step_index) self.model.scheduler.prepare(self.inputs["image_encoder_output"])
self.model.infer(self.inputs) self.model.scheduler.step_pre(step_index=step_index)
self.model.scheduler.step_post() self.model.infer(self.inputs)
self.model.scheduler.step_post()
def end_run(self): def end_run(self):
self.model.scheduler.clear() self.model.scheduler.clear()
...@@ -148,7 +147,7 @@ class DefaultRunner: ...@@ -148,7 +147,7 @@ class DefaultRunner:
gc.collect() gc.collect()
@ProfilingContext("Run Encoders") @ProfilingContext("Run Encoders")
async def run_input_encoder_local_i2v(self): def _run_input_encoder_local_i2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"] prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
img = Image.open(self.config["image_path"]).convert("RGB") img = Image.open(self.config["image_path"]).convert("RGB")
clip_encoder_out = self.run_image_encoder(img) clip_encoder_out = self.run_image_encoder(img)
...@@ -159,7 +158,7 @@ class DefaultRunner: ...@@ -159,7 +158,7 @@ class DefaultRunner:
return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img) return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
@ProfilingContext("Run Encoders") @ProfilingContext("Run Encoders")
async def run_input_encoder_local_t2v(self): def _run_input_encoder_local_t2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"] prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt, None) text_encoder_output = self.run_text_encoder(prompt, None)
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -167,7 +166,7 @@ class DefaultRunner: ...@@ -167,7 +166,7 @@ class DefaultRunner:
return {"text_encoder_output": text_encoder_output, "image_encoder_output": None} return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
@ProfilingContext("Run DiT") @ProfilingContext("Run DiT")
async def run_dit_local(self, kwargs): def _run_dit_local(self, kwargs):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer() self.model = self.load_transformer()
self.init_scheduler() self.init_scheduler()
...@@ -177,7 +176,7 @@ class DefaultRunner: ...@@ -177,7 +176,7 @@ class DefaultRunner:
return latents, generator return latents, generator
@ProfilingContext("Run VAE Decoder") @ProfilingContext("Run VAE Decoder")
async def run_vae_decoder_local(self, latents, generator): def _run_vae_decoder_local(self, latents, generator):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae_decoder() self.vae_decoder = self.load_vae_decoder()
images = self.vae_decoder.decode(latents, generator=generator, config=self.config) images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
...@@ -192,20 +191,51 @@ class DefaultRunner: ...@@ -192,20 +191,51 @@ class DefaultRunner:
if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0): if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
self.save_video_func(images) self.save_video_func(images)
async def post_task(self, task_type, urls, message, device="cuda"): async def post_task(self, task_type, urls, message, device="cuda", max_retries=3, timeout=30):
while True: for attempt in range(max_retries):
for url in urls: for url in urls:
async with aiohttp.ClientSession() as session: try:
async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response: async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
status = await response.json() try:
if status["service_status"] == "idle": async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response:
async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response: if response.status != 200:
result = await response.json() logger.warning(f"Service {url} returned status {response.status}")
if result["kwargs"] is not None: continue
for k, v in result["kwargs"].items(): status = await response.json()
setattr(self.config, k, v) except asyncio.TimeoutError:
return self.tensor_transporter.load_tensor(result["output"], device) logger.warning(f"Timeout checking status for {url}")
await asyncio.sleep(0.1) continue
except Exception as e:
logger.warning(f"Error checking status for {url}: {e}")
continue
if status.get("service_status") == "idle":
try:
async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response:
if response.status == 200:
result = await response.json()
if result.get("kwargs") is not None:
for k, v in result["kwargs"].items():
setattr(self.config, k, v)
return self.tensor_transporter.load_tensor(result["output"], device)
else:
logger.warning(f"Task failed with status {response.status} for {url}")
except asyncio.TimeoutError:
logger.warning(f"Timeout posting task to {url}")
except Exception as e:
logger.error(f"Error posting task to {url}: {e}")
except aiohttp.ClientError as e:
logger.warning(f"Client error for {url}: {e}")
except Exception as e:
logger.error(f"Unexpected error for {url}: {e}")
if attempt < max_retries - 1:
wait_time = min(2**attempt, 10)
logger.info(f"Retrying in {wait_time} seconds... (attempt {attempt + 1}/{max_retries})")
await asyncio.sleep(wait_time)
raise RuntimeError(f"Failed to complete task {task_type} after {max_retries} attempts")
def post_prompt_enhancer(self): def post_prompt_enhancer(self):
while True: while True:
...@@ -256,7 +286,7 @@ class DefaultRunner: ...@@ -256,7 +286,7 @@ class DefaultRunner:
# text_encoders # text_encoders
return results[0] return results[0]
async def run_input_encoder_server_i2v(self): async def _run_input_encoder_server_i2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"] prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
n_prompt = self.config.get("negative_prompt", "") n_prompt = self.config.get("negative_prompt", "")
img = Image.open(self.config["image_path"]).convert("RGB") img = Image.open(self.config["image_path"]).convert("RGB")
...@@ -265,7 +295,7 @@ class DefaultRunner: ...@@ -265,7 +295,7 @@ class DefaultRunner:
gc.collect() gc.collect()
return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img) return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
async def run_input_encoder_server_t2v(self): async def _run_input_encoder_server_t2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"] prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
n_prompt = self.config.get("negative_prompt", "") n_prompt = self.config.get("negative_prompt", "")
text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt) text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
...@@ -273,7 +303,7 @@ class DefaultRunner: ...@@ -273,7 +303,7 @@ class DefaultRunner:
gc.collect() gc.collect()
return {"text_encoder_output": text_encoder_output, "image_encoder_output": None} return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
async def run_dit_server(self, kwargs): async def _run_dit_server(self, kwargs):
if self.inputs.get("image_encoder_output", None) is not None: if self.inputs.get("image_encoder_output", None) is not None:
self.inputs["image_encoder_output"].pop("img", None) self.inputs["image_encoder_output"].pop("img", None)
dit_output = await self.post_task( dit_output = await self.post_task(
...@@ -284,7 +314,7 @@ class DefaultRunner: ...@@ -284,7 +314,7 @@ class DefaultRunner:
) )
return dit_output, None return dit_output, None
async def run_vae_decoder_server(self, latents, generator): async def _run_vae_decoder_server(self, latents, generator):
images = await self.post_task( images = await self.post_task(
task_type="vae_model/decoder", task_type="vae_model/decoder",
urls=self.config["sub_servers"]["vae_model"], urls=self.config["sub_servers"]["vae_model"],
...@@ -293,14 +323,24 @@ class DefaultRunner: ...@@ -293,14 +323,24 @@ class DefaultRunner:
) )
return images return images
async def run_pipeline(self): async def run_pipeline(self, save_video=True):
if self.config["use_prompt_enhancer"]: async with AsyncWrapper(self) as wrapper:
self.config["prompt_enhanced"] = self.post_prompt_enhancer() if self.config["use_prompt_enhancer"]:
self.inputs = await self.run_input_encoder() self.config["prompt_enhanced"] = await wrapper.run_prompt_enhancer()
kwargs = self.set_target_shape()
latents, generator = await self.run_dit(kwargs) self.inputs = await wrapper.run_input_encoder()
images = await self.run_vae_decoder(latents, generator)
self.save_video(images) kwargs = self.set_target_shape()
del latents, generator, images
torch.cuda.empty_cache() latents, generator = await wrapper.run_dit(kwargs)
gc.collect()
images = await wrapper.run_vae_decoder(latents, generator)
if save_video:
await wrapper.save_video(images)
del latents, generator
torch.cuda.empty_cache()
gc.collect()
return images
import aiofiles
import asyncio
from PIL import Image
import io
from typing import Union
from pathlib import Path
from loguru import logger
async def load_image_async(path: Union[str, Path]) -> Image.Image:
    """Read an image file without blocking the event loop.

    Bytes are read via aiofiles; PIL decoding (and RGB conversion) runs in
    a worker thread. Failures are logged and re-raised.
    """
    def _decode(raw: bytes) -> Image.Image:
        return Image.open(io.BytesIO(raw)).convert("RGB")

    try:
        async with aiofiles.open(path, "rb") as fh:
            raw = await fh.read()
        return await asyncio.to_thread(_decode, raw)
    except Exception as e:
        logger.error(f"Failed to load image from {path}: {e}")
        raise
async def save_video_async(video_path: Union[str, Path], video_data: bytes):
    """Asynchronously write encoded video bytes to ``video_path``.

    Parent directories are created as needed; failures are logged and
    re-raised.
    """
    try:
        video_path = Path(video_path)
        video_path.parent.mkdir(parents=True, exist_ok=True)
        async with aiofiles.open(video_path, "wb") as fh:
            await fh.write(video_data)
        logger.info(f"Video saved to {video_path}")
    except Exception as e:
        logger.error(f"Failed to save video to {video_path}: {e}")
        raise
async def read_text_async(path: Union[str, Path], encoding: str = "utf-8") -> str:
    """Read a whole text file asynchronously and return its contents.

    Failures are logged and re-raised.
    """
    try:
        async with aiofiles.open(path, "r", encoding=encoding) as fh:
            content = await fh.read()
        return content
    except Exception as e:
        logger.error(f"Failed to read text from {path}: {e}")
        raise
async def write_text_async(path: Union[str, Path], content: str, encoding: str = "utf-8"):
    """Asynchronously write ``content`` to a text file.

    Parent directories are created as needed; failures are logged and
    re-raised.
    """
    try:
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        async with aiofiles.open(path, "w", encoding=encoding) as fh:
            await fh.write(content)
        logger.info(f"Text written to {path}")
    except Exception as e:
        logger.error(f"Failed to write text to {path}: {e}")
        raise
async def exists_async(path: Union[str, Path]) -> bool:
    """Check whether ``path`` exists, off the event loop.

    The filesystem stat runs in a worker thread so the caller never blocks.
    """
    target = Path(path)
    return await asyncio.to_thread(target.exists)
async def read_bytes_async(path: Union[str, Path]) -> bytes:
    """Read a whole binary file asynchronously and return its bytes.

    Failures are logged and re-raised.
    """
    try:
        async with aiofiles.open(path, "rb") as fh:
            data = await fh.read()
        return data
    except Exception as e:
        logger.error(f"Failed to read bytes from {path}: {e}")
        raise
async def write_bytes_async(path: Union[str, Path], data: bytes):
    """Asynchronously write raw ``data`` to a binary file.

    Parent directories are created as needed; failures are logged and
    re-raised.
    """
    try:
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        async with aiofiles.open(path, "wb") as fh:
            await fh.write(data)
        logger.debug(f"Bytes written to {path}")
    except Exception as e:
        logger.error(f"Failed to write bytes to {path}: {e}")
        raise
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment