Unverified Commit 61dd69ca authored by Yang Yong (雍洋)'s avatar Yang Yong (雍洋) Committed by GitHub
Browse files

Update lightx2v_platform (#553)

parent aa627f77
......@@ -23,6 +23,8 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
@RUNNER_REGISTER("hunyuan_video_1.5")
class HunyuanVideo15Runner(DefaultRunner):
......@@ -349,8 +351,7 @@ class HunyuanVideo15Runner(DefaultRunner):
self.model_sr.scheduler.step_post()
del self.inputs_sr
torch_ext_module = getattr(torch, AI_DEVICE)
torch_ext_module.empty_cache()
torch_device_module.empty_cache()
self.config_sr["is_sr_running"] = False
return self.model_sr.scheduler.latents
......@@ -371,8 +372,7 @@ class HunyuanVideo15Runner(DefaultRunner):
siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(AI_DEVICE)
siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(AI_DEVICE))
torch_ext_module = getattr(torch, AI_DEVICE)
torch_ext_module.empty_cache()
torch_device_module.empty_cache()
gc.collect()
return {
"text_encoder_output": text_encoder_output,
......@@ -399,8 +399,7 @@ class HunyuanVideo15Runner(DefaultRunner):
siglip_output, siglip_mask = self.run_image_encoder(img_ori) if self.config.get("use_image_encoder", True) else None
cond_latents = self.run_vae_encoder(img_ori)
text_encoder_output = self.run_text_encoder(self.input_info)
torch_ext_module = getattr(torch, AI_DEVICE)
torch_ext_module.empty_cache()
torch_device_module.empty_cache()
gc.collect()
return {
"text_encoder_output": text_encoder_output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment