Commit 514ea716 authored by helloyongyang

remove split server & fix some bugs

parent a23bef13
(deleted file — async_wrapper.py, named by the removed import in the runner diff below)

import asyncio
import functools
from typing import Callable, Any, Optional
from concurrent.futures import ThreadPoolExecutor


class AsyncWrapper:
    """Async facade over a runner: routes each pipeline stage to a remote
    sub-server (split_server mode) or to a thread pool for local execution."""

    def __init__(self, runner, max_workers: Optional[int] = None):
        self.runner = runner
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.executor:
            self.executor.shutdown(wait=True)

    async def run_in_executor(self, func: Callable, *args, **kwargs) -> Any:
        # loop.run_in_executor() forwards positional args only; kwargs must be
        # bound with functools.partial (passing **kwargs directly raises TypeError).
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, functools.partial(func, *args, **kwargs))

    async def run_input_encoder(self):
        if self.runner.config["mode"] == "split_server":
            if self.runner.config["task"] == "i2v":
                return await self.runner._run_input_encoder_server_i2v()
            else:
                return await self.runner._run_input_encoder_server_t2v()
        else:
            if self.runner.config["task"] == "i2v":
                return await self.run_in_executor(self.runner._run_input_encoder_local_i2v)
            else:
                return await self.run_in_executor(self.runner._run_input_encoder_local_t2v)

    async def run_dit(self, kwargs):
        if self.runner.config["mode"] == "split_server":
            return await self.runner._run_dit_server(kwargs)
        else:
            return await self.run_in_executor(self.runner._run_dit_local, kwargs)

    async def run_vae_decoder(self, latents, generator):
        if self.runner.config["mode"] == "split_server":
            return await self.runner._run_vae_decoder_server(latents, generator)
        else:
            return await self.run_in_executor(self.runner._run_vae_decoder_local, latents, generator)

    async def run_prompt_enhancer(self):
        if self.runner.config["use_prompt_enhancer"]:
            return await self.run_in_executor(self.runner.post_prompt_enhancer)
        return None

    async def save_video(self, images):
        return await self.run_in_executor(self.runner.save_video, images)
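For orientation, a minimal usage sketch of the class above, assuming a stub runner object (`DummyRunner` and its config keys are illustrative, not part of lightx2v):

```python
import asyncio

class DummyRunner:
    # Stub standing in for the real runner; only what AsyncWrapper touches.
    config = {"mode": "local", "task": "t2v", "use_prompt_enhancer": False}

    def _run_input_encoder_local_t2v(self):
        return {"text_encoder_output": "...", "image_encoder_output": None}

async def main():
    async with AsyncWrapper(DummyRunner()) as wrapper:
        # The blocking encoder call is pushed onto the wrapper's thread pool.
        inputs = await wrapper.run_input_encoder()
        print(inputs)

asyncio.run(main())
```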
(modified file — the DefaultRunner module)

-import asyncio
 import gc
-import aiohttp
 import requests
 from requests.exceptions import RequestException
 import torch
@@ -13,7 +11,6 @@ from lightx2v.utils.generate_task_id import generate_task_id
 from lightx2v.utils.envs import *
 from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
 from loguru import logger
-from .async_wrapper import AsyncWrapper
 from .base_runner import BaseRunner
@@ -33,21 +30,14 @@ class DefaultRunner(BaseRunner):
     def init_modules(self):
         logger.info("Initializing runner modules...")
-        if self.config["mode"] == "split_server":
-            self.tensor_transporter = TensorTransporter()
-            self.image_transporter = ImageTransporter()
-            if not self.check_sub_servers("dit"):
-                raise ValueError("No dit server available")
-            if not self.check_sub_servers("text_encoders"):
-                raise ValueError("No text encoder server available")
-            if self.config["task"] == "i2v":
-                if not self.check_sub_servers("image_encoder"):
-                    raise ValueError("No image encoder server available")
-                if not self.check_sub_servers("vae_model"):
-                    raise ValueError("No vae server available")
-        else:
-            if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
-                self.load_model()
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            self.load_model()
+        self.run_dit = self._run_dit_local
+        self.run_vae_decoder = self._run_vae_decoder_local
+        if self.config["task"] == "i2v":
+            self.run_input_encoder = self._run_input_encoder_local_i2v
+        else:
+            self.run_input_encoder = self._run_input_encoder_local_t2v

     def set_init_device(self):
         if self.config["parallel_attn_type"]:
@@ -123,14 +113,13 @@ class DefaultRunner(BaseRunner):
         return self.model.scheduler.latents, self.model.scheduler.generator

-    async def run_step(self, step_index=0):
-        async with AsyncWrapper(self) as wrapper:
-            self.init_scheduler()
-            self.inputs = await wrapper.run_input_encoder()
-            self.model.scheduler.prepare(self.inputs["image_encoder_output"])
-            self.model.scheduler.step_pre(step_index=step_index)
-            self.model.infer(self.inputs)
-            self.model.scheduler.step_post()
+    def run_step(self, step_index=0):
+        self.init_scheduler()
+        self.inputs = self.run_input_encoder()
+        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
+        self.model.scheduler.step_pre(step_index=step_index)
+        self.model.infer(self.inputs)
+        self.model.scheduler.step_post()

     def end_run(self):
         self.model.scheduler.clear()

@@ -194,52 +183,6 @@ class DefaultRunner(BaseRunner):
         if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
             self.save_video_func(images)
-    async def post_task(self, task_type, urls, message, device="cuda", max_retries=3, timeout=30):
-        for attempt in range(max_retries):
-            for url in urls:
-                try:
-                    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
-                        try:
-                            async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response:
-                                if response.status != 200:
-                                    logger.warning(f"Service {url} returned status {response.status}")
-                                    continue
-                                status = await response.json()
-                        except asyncio.TimeoutError:
-                            logger.warning(f"Timeout checking status for {url}")
-                            continue
-                        except Exception as e:
-                            logger.warning(f"Error checking status for {url}: {e}")
-                            continue
-                        if status.get("service_status") == "idle":
-                            try:
-                                async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response:
-                                    if response.status == 200:
-                                        result = await response.json()
-                                        if result.get("kwargs") is not None:
-                                            for k, v in result["kwargs"].items():
-                                                setattr(self.config, k, v)
-                                        return self.tensor_transporter.load_tensor(result["output"], device)
-                                    else:
-                                        logger.warning(f"Task failed with status {response.status} for {url}")
-                            except asyncio.TimeoutError:
-                                logger.warning(f"Timeout posting task to {url}")
-                            except Exception as e:
-                                logger.error(f"Error posting task to {url}: {e}")
-                except aiohttp.ClientError as e:
-                    logger.warning(f"Client error for {url}: {e}")
-                except Exception as e:
-                    logger.error(f"Unexpected error for {url}: {e}")
-            if attempt < max_retries - 1:
-                wait_time = min(2**attempt, 10)
-                logger.info(f"Retrying in {wait_time} seconds... (attempt {attempt + 1}/{max_retries})")
-                await asyncio.sleep(wait_time)
-        raise RuntimeError(f"Failed to complete task {task_type} after {max_retries} attempts")
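The deleted post_task retried each task with capped exponential backoff (`wait_time = min(2**attempt, 10)`). The schedule in isolation, as a runnable sketch:

```python
def backoff_delays(max_retries: int = 3, cap: int = 10) -> list:
    # One delay per retry, i.e. after every attempt except the last.
    return [min(2**attempt, cap) for attempt in range(max_retries - 1)]

print(backoff_delays())   # [1, 2]           -> the default 3-attempt schedule
print(backoff_delays(6))  # [1, 2, 4, 8, 10] -> growth is capped at 10 seconds
```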
     def post_prompt_enhancer(self):
         while True:
             for url in self.config["sub_servers"]["prompt_enhancer"]:
@@ -256,138 +199,23 @@ class DefaultRunner(BaseRunner):
                 logger.info(f"Enhanced prompt: {enhanced_prompt}")
                 return enhanced_prompt
-    async def post_encoders_i2v(self, prompt, img=None, n_prompt=None, i2v=False):
-        tasks = []
-        img_byte = self.image_transporter.prepare_image(img)
-        tasks.append(
-            asyncio.create_task(
-                self.post_task(
-                    task_type="image_encoder",
-                    urls=self.config["sub_servers"]["image_encoder"],
-                    message={"task_id": generate_task_id(), "img": img_byte},
-                    device="cuda",
-                )
-            )
-        )
-        tasks.append(
-            asyncio.create_task(
-                self.post_task(
-                    task_type="vae_model/encoder",
-                    urls=self.config["sub_servers"]["vae_model"],
-                    message={"task_id": generate_task_id(), "img": img_byte},
-                    device="cuda",
-                )
-            )
-        )
-        tasks.append(
-            asyncio.create_task(
-                self.post_task(
-                    task_type="text_encoders",
-                    urls=self.config["sub_servers"]["text_encoders"],
-                    message={
-                        "task_id": generate_task_id(),
-                        "text": prompt,
-                        "img": img_byte,
-                        "n_prompt": n_prompt,
-                    },
-                    device="cuda",
-                )
-            )
-        )
-        results = await asyncio.gather(*tasks)
-        # clip_encoder, vae_encoder, text_encoders
-        return results[0], results[1], results[2]
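post_encoders_i2v above fanned the image, VAE, and text encoder requests out concurrently and relied on asyncio.gather returning results in submission order. The same pattern reduced to a runnable sketch (`fake_encoder` stands in for the HTTP round trip):

```python
import asyncio

async def fake_encoder(name: str, delay: float) -> str:
    await asyncio.sleep(delay)  # placeholder for a network call
    return f"{name}-output"

async def main():
    # gather() preserves submission order, so the unpacking below is safe
    # even though the coroutines finish at different times.
    clip, vae, text = await asyncio.gather(
        fake_encoder("image_encoder", 0.2),
        fake_encoder("vae_model/encoder", 0.1),
        fake_encoder("text_encoders", 0.3),
    )
    print(clip, vae, text)

asyncio.run(main())
```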
-    async def post_encoders_t2v(self, prompt, n_prompt=None):
-        tasks = []
-        tasks.append(
-            asyncio.create_task(
-                self.post_task(
-                    task_type="text_encoders",
-                    urls=self.config["sub_servers"]["text_encoders"],
-                    message={
-                        "task_id": generate_task_id(),
-                        "text": prompt,
-                        "img": None,
-                        "n_prompt": n_prompt,
-                    },
-                    device="cuda",
-                )
-            )
-        )
-        results = await asyncio.gather(*tasks)
-        # text_encoders
-        return results[0]
-
-    async def _run_input_encoder_server_i2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        n_prompt = self.config.get("negative_prompt", "")
-        img = Image.open(self.config["image_path"]).convert("RGB")
-        (
-            clip_encoder_out,
-            vae_encode_out,
-            text_encoder_output,
-        ) = await self.post_encoders_i2v(prompt, img, n_prompt)
-        torch.cuda.empty_cache()
-        gc.collect()
-        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
-
-    async def _run_input_encoder_server_t2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        n_prompt = self.config.get("negative_prompt", "")
-        text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
-        torch.cuda.empty_cache()
-        gc.collect()
-        return {
-            "text_encoder_output": text_encoder_output,
-            "image_encoder_output": None,
-        }
-
-    async def _run_dit_server(self, kwargs):
-        if self.inputs.get("image_encoder_output", None) is not None:
-            self.inputs["image_encoder_output"].pop("img", None)
-        dit_output = await self.post_task(
-            task_type="dit",
-            urls=self.config["sub_servers"]["dit"],
-            message={
-                "task_id": generate_task_id(),
-                "inputs": self.tensor_transporter.prepare_tensor(self.inputs),
-                "kwargs": self.tensor_transporter.prepare_tensor(kwargs),
-            },
-            device="cuda",
-        )
-        return dit_output, None
-
-    async def _run_vae_decoder_server(self, latents, generator):
-        images = await self.post_task(
-            task_type="vae_model/decoder",
-            urls=self.config["sub_servers"]["vae_model"],
-            message={
-                "task_id": generate_task_id(),
-                "latents": self.tensor_transporter.prepare_tensor(latents),
-            },
-            device="cpu",
-        )
-        return images
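TensorTransporter's prepare_tensor/load_tensor are not shown in this diff; under the assumption that they serialize tensors for the HTTP hop, a generic torch round trip looks roughly like this (a sketch, not the library's actual implementation):

```python
import io
import torch

def prepare_tensor(obj) -> bytes:
    # Serialize a tensor (or a nested structure of tensors) to raw bytes.
    buf = io.BytesIO()
    torch.save(obj, buf)
    return buf.getvalue()

def load_tensor(payload: bytes, device: str):
    # Deserialize onto the requested device, mirroring the device= argument
    # that post_task passed through ("cuda" for latents, "cpu" for frames).
    return torch.load(io.BytesIO(payload), map_location=device)

latents = torch.randn(1, 4, 8, 8)
assert torch.equal(latents, load_tensor(prepare_tensor(latents), "cpu"))
```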
-    async def run_pipeline(self, save_video=True):
-        async with AsyncWrapper(self) as wrapper:
-            if self.config["use_prompt_enhancer"]:
-                self.config["prompt_enhanced"] = await wrapper.run_prompt_enhancer()
-            self.inputs = await wrapper.run_input_encoder()
-            kwargs = self.set_target_shape()
-            latents, generator = await wrapper.run_dit(kwargs)
-            images = await wrapper.run_vae_decoder(latents, generator)
-            if save_video:
-                await wrapper.save_video(images)
-            del latents, generator
-            torch.cuda.empty_cache()
-            gc.collect()
-            return images
+    def run_pipeline(self, save_video=True):
+        if self.config["use_prompt_enhancer"]:
+            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
+        self.inputs = self.run_input_encoder()
+        kwargs = self.set_target_shape()
+        latents, generator = self.run_dit(kwargs)
+        images = self.run_vae_decoder(latents, generator)
+        if save_video:
+            self.save_video(images)
+        del latents, generator
+        torch.cuda.empty_cache()
+        gc.collect()
+        return images
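After this change the pipeline is a plain synchronous call chain. A hypothetical caller (runner construction and config contents are not shown in this diff, so treat the setup lines as illustrative):

```python
# Illustrative only: DefaultRunner's constructor arguments are not part of
# this diff; config must supply the task, prompts, model paths, etc.
runner = DefaultRunner(config)
runner.init_modules()                          # binds run_dit / run_vae_decoder / run_input_encoder
images = runner.run_pipeline(save_video=True)  # no event loop required anymore
```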