Unverified Commit b50498fa authored by Yang Yong (雍洋)'s avatar Yang Yong (雍洋) Committed by GitHub
Browse files

Add lightx2v_platform (#541)

parent 31da6925
......@@ -2,18 +2,18 @@ import torch
from transformers import AutoFeatureExtractor, AutoModel
from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE
class SekoAudioEncoderModel:
def __init__(self, model_path, audio_sr, cpu_offload, run_device):
def __init__(self, model_path, audio_sr, cpu_offload):
self.model_path = model_path
self.audio_sr = audio_sr
self.cpu_offload = cpu_offload
if self.cpu_offload:
self.device = torch.device("cpu")
else:
self.device = torch.device(run_device)
self.run_device = run_device
self.device = torch.device(AI_DEVICE)
self.load()
def load(self):
......@@ -27,13 +27,13 @@ class SekoAudioEncoderModel:
self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
def to_cuda(self):
self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
self.audio_feature_encoder = self.audio_feature_encoder.to(AI_DEVICE)
@torch.no_grad()
def infer(self, audio_segment):
audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.run_device).to(dtype=GET_DTYPE())
audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(AI_DEVICE).to(dtype=GET_DTYPE())
if self.cpu_offload:
self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
self.audio_feature_encoder = self.audio_feature_encoder.to(AI_DEVICE)
audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
if self.cpu_offload:
self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
......
......@@ -24,8 +24,8 @@ from lightx2v.models.input_encoders.hf.q_linear import ( # noqa E402
SglQuantLinearFp8, # noqa E402
TorchaoQuantLinearInt8, # noqa E402
VllmQuantLinearInt8, # noqa E402,
MluQuantLinearInt8,
)
from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8 # noqa E402
from lightx2v.models.input_encoders.hf.wan.t5.tokenizer import HuggingfaceTokenizer # noqa E402
from lightx2v.utils.envs import * # noqa E402
from lightx2v.utils.registry_factory import ( # noqa E402
......@@ -34,6 +34,7 @@ from lightx2v.utils.registry_factory import ( # noqa E402
RMS_WEIGHT_REGISTER, # noqa E402
)
from lightx2v.utils.utils import load_weights # noqa E402
from lightx2v_platform.base.global_var import AI_DEVICE # noqa E402
__all__ = [
"T5Model",
......@@ -745,7 +746,6 @@ class T5EncoderModel:
text_len,
dtype=torch.bfloat16,
device=torch.device("cuda"),
run_device=torch.device("cuda"),
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,
......@@ -758,7 +758,6 @@ class T5EncoderModel:
self.text_len = text_len
self.dtype = dtype
self.device = device
self.run_device = run_device
if t5_quantized_ckpt is not None and t5_quantized:
self.checkpoint_path = t5_quantized_ckpt
else:
......@@ -807,8 +806,8 @@ class T5EncoderModel:
def infer(self, texts):
ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
ids = ids.to(self.run_device)
mask = mask.to(self.run_device)
ids = ids.to(AI_DEVICE)
mask = mask.to(AI_DEVICE)
seq_lens = mask.gt(0).sum(dim=1).long()
with torch.no_grad():
......
......@@ -10,8 +10,10 @@ from loguru import logger
# from lightx2v.attentions import attention
from lightx2v.common.ops.attn import TorchSDPAWeight
from lightx2v.models.input_encoders.hf.q_linear import MluQuantLinearInt8, Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
from lightx2v.utils.utils import load_weights
from lightx2v_platform.base.global_var import AI_DEVICE
from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8
__all__ = [
"XLMRobertaCLIP",
......@@ -426,9 +428,8 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
class CLIPModel:
def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False, run_device=torch.device("cuda")):
def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False):
self.dtype = dtype
self.run_device = run_device
self.quantized = clip_quantized
self.cpu_offload = cpu_offload
self.use_31_block = use_31_block
......@@ -462,7 +463,7 @@ class CLIPModel:
return out
def to_cuda(self):
self.model = self.model.to(self.run_device)
self.model = self.model.to(AI_DEVICE)
def to_cpu(self):
self.model = self.model.cpu()
......@@ -5,6 +5,7 @@ import torch
from einops import rearrange
from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE
from .attn_no_pad import flash_attn_no_pad, flash_attn_no_pad_v3, sage_attn_no_pad_v2
from .module_io import HunyuanVideo15InferModuleOutput
......@@ -68,7 +69,6 @@ class HunyuanVideo15PreInfer:
self.heads_num = config["heads_num"]
self.frequency_embedding_size = 256
self.max_period = 10000
self.run_device = torch.device(self.config.get("run_device", "cuda"))
def set_scheduler(self, scheduler):
self.scheduler = scheduler
......@@ -155,7 +155,7 @@ class HunyuanVideo15PreInfer:
byt5_txt = byt5_txt + weights.cond_type_embedding.apply(torch.ones_like(byt5_txt[:, :, 0], device=byt5_txt.device, dtype=torch.long))
txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=True)
siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=self.run_device))
siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=AI_DEVICE))
txt, text_mask = self.reorder_txt_token(siglip_output, txt, siglip_mask, text_mask)
txt = txt[:, : text_mask.sum(), :]
......
......@@ -10,6 +10,7 @@ except Exception as e:
apply_rope_with_cos_sin_cache_inplace = None
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v_platform.base.global_var import AI_DEVICE
from .module_io import HunyuanVideo15ImgBranchOutput, HunyuanVideo15TxtBranchOutput
from .triton_ops import fuse_scale_shift_kernel
......@@ -100,7 +101,6 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
self.config = config
self.double_blocks_num = config["mm_double_blocks_depth"]
self.heads_num = config["heads_num"]
self.run_device = torch.device(self.config.get("run_device", "cuda"))
if self.config["seq_parallel"]:
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)
......@@ -222,7 +222,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
key = torch.cat([img_k, txt_k], dim=1)
value = torch.cat([img_v, txt_v], dim=1)
seqlen = query.shape[1]
cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(self.run_device, non_blocking=True)
cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True)
if self.config["seq_parallel"]:
attn_out = weights.self_attention_parallel.apply(
......
......@@ -176,12 +176,12 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
if self.device.type != "cpu" and dist.is_initialized():
device = torch.device("{}:{}".format(self.device.type, dist.get_rank()))
if self.config["parallel"]:
device = dist.get_rank()
else:
device = self.device
device = str(self.device)
with safe_open(file_path, framework="pt", device=str(device)) as f:
with safe_open(file_path, framework="pt", device=device) as f:
return {
key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
for key in f.keys()
......
......@@ -111,7 +111,7 @@ def apply_attn(block_weight, hidden_states, encoder_hidden_states, image_rotary_
if attn_type == "torch_sdpa":
joint_hidden_states = block_weight.attn.calculate.apply(q=joint_query, k=joint_key, v=joint_value)
elif attn_type in ["flash_attn3", "sage_attn2", "mlu_flash_attn", "flash_attn2", "mlu_sage_attn"]:
else:
joint_query = joint_query.squeeze(0)
joint_key = joint_key.squeeze(0)
joint_value = joint_value.squeeze(0)
......
......@@ -8,6 +8,7 @@ from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *
from lightx2v_platform.base.global_var import AI_DEVICE
from .infer.offload.transformer_infer import QwenImageOffloadTransformerInfer
from .infer.post_infer import QwenImagePostInfer
......@@ -28,7 +29,7 @@ class QwenImageTransformerModel:
self.model_path = os.path.join(config["model_path"], "transformer")
self.cpu_offload = config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block")
self.device = torch.device("cpu") if self.cpu_offload else torch.device(self.config.get("run_device", "cuda"))
self.device = torch.device("cpu") if self.cpu_offload else torch.device(AI_DEVICE)
with open(os.path.join(config["model_path"], "transformer", "config.json"), "r") as f:
transformer_config = json.load(f)
......@@ -124,12 +125,12 @@ class QwenImageTransformerModel:
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
if self.device.type in ["cuda", "mlu", "npu"] and dist.is_initialized():
device = torch.device("{}:{}".format(self.device.type, dist.get_rank()))
if self.config["parallel"]:
device = dist.get_rank()
else:
device = self.device
device = str(self.device)
with safe_open(file_path, framework="pt", device=str(device)) as f:
with safe_open(file_path, framework="pt", device=device) as f:
return {
key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
for key in f.keys()
......
......@@ -13,6 +13,7 @@ from lightx2v.models.networks.wan.weights.audio.transformer_weights import WanAu
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.utils.utils import load_weights
from lightx2v_platform.base.global_var import AI_DEVICE
class WanAudioModel(WanModel):
......@@ -22,7 +23,6 @@ class WanAudioModel(WanModel):
def __init__(self, model_path, config, device):
self.config = config
self.run_device = self.config.get("run_device", "cuda")
self._load_adapter_ckpt()
super().__init__(model_path, config, device)
......@@ -51,7 +51,7 @@ class WanAudioModel(WanModel):
if not adapter_offload:
if not dist.is_initialized() or not load_from_rank0:
for key in self.adapter_weights_dict:
self.adapter_weights_dict[key] = self.adapter_weights_dict[key].to(torch.device(self.run_device))
self.adapter_weights_dict[key] = self.adapter_weights_dict[key].to(torch.device(AI_DEVICE))
def _init_infer_class(self):
super()._init_infer_class()
......
......@@ -10,10 +10,8 @@ class WanPreInfer:
def __init__(self, config):
assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
self.config = config
self.run_device = self.config.get("run_device", "cuda")
self.clean_cuda_cache = config.get("clean_cuda_cache", False)
self.task = config["task"]
self.device = torch.device(self.config.get("run_device", "cuda"))
self.freq_dim = config["freq_dim"]
self.dim = config["dim"]
self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
......
......@@ -124,7 +124,7 @@ def fuse_scale_shift_kernel(
block_l: int = 128,
block_c: int = 128,
):
assert x.is_cuda and scale.is_cuda
# assert x.is_cuda and scale.is_cuda
assert x.is_contiguous()
if x.dim() == 2:
x = x.unsqueeze(0)
......
......@@ -44,7 +44,6 @@ class WanModel(CompiledMethodsMixin):
super().__init__()
self.model_path = model_path
self.config = config
self.device = self.config.get("run_device", "cuda")
self.cpu_offload = self.config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block")
self.model_type = model_type
......@@ -147,12 +146,12 @@ class WanModel(CompiledMethodsMixin):
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
if (self.device.type == "cuda" or self.device.type == "mlu") and dist.is_initialized():
device = torch.device("{}:{}".format(self.device.type, dist.get_rank()))
if self.config["parallel"]:
device = dist.get_rank()
else:
device = self.device
device = str(self.device)
with safe_open(file_path, framework="pt", device=str(device)) as f:
with safe_open(file_path, framework="pt", device=device) as f:
return {
key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
for key in f.keys()
......
......@@ -3,6 +3,8 @@ from abc import ABC
import torch
import torch.distributed as dist
from lightx2v_platform.base.global_var import AI_DEVICE
class BaseRunner(ABC):
"""Abstract base class for all Runners
......@@ -145,9 +147,9 @@ class BaseRunner(ABC):
if world_size > 1:
if rank == signal_rank:
t = torch.tensor([stopped], dtype=torch.int32).to(device=self.config.get("run_device", "cuda"))
t = torch.tensor([stopped], dtype=torch.int32).to(device=AI_DEVICE)
else:
t = torch.zeros(1, dtype=torch.int32, device=self.config.get("run_device", "cuda"))
t = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
dist.broadcast(t, src=signal_rank)
stopped = t.item()
......
......@@ -15,6 +15,7 @@ from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
from lightx2v_platform.base.global_var import AI_DEVICE
from .base_runner import BaseRunner
......@@ -59,11 +60,10 @@ class DefaultRunner(BaseRunner):
self.model.compile(self.config.get("compile_shapes", []))
def set_init_device(self):
self.run_device = self.config.get("run_device", "cuda")
if self.config["cpu_offload"]:
self.init_device = torch.device("cpu")
else:
self.init_device = torch.device(self.config.get("run_device", "cuda"))
self.init_device = torch.device(AI_DEVICE)
def load_vfi_model(self):
if self.config["video_frame_interpolation"].get("algo", None) == "rife":
......
......@@ -21,6 +21,7 @@ from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v_platform.base.global_var import AI_DEVICE
@RUNNER_REGISTER("hunyuan_video_1.5")
......@@ -71,7 +72,7 @@ class HunyuanVideo15Runner(DefaultRunner):
if qwen25vl_offload:
qwen25vl_device = torch.device("cpu")
else:
qwen25vl_device = torch.device(self.run_device)
qwen25vl_device = torch.device(AI_DEVICE)
qwen25vl_quantized = self.config.get("qwen25vl_quantized", False)
qwen25vl_quant_scheme = self.config.get("qwen25vl_quant_scheme", None)
......@@ -82,7 +83,6 @@ class HunyuanVideo15Runner(DefaultRunner):
text_encoder = Qwen25VL_TextEncoder(
dtype=torch.float16,
device=qwen25vl_device,
run_device=self.run_device,
checkpoint_path=text_encoder_path,
cpu_offload=qwen25vl_offload,
qwen25vl_quantized=qwen25vl_quantized,
......@@ -94,9 +94,9 @@ class HunyuanVideo15Runner(DefaultRunner):
if byt5_offload:
byt5_device = torch.device("cpu")
else:
byt5_device = torch.device(self.run_device)
byt5_device = torch.device(AI_DEVICE)
byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, run_device=self.run_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
text_encoders = [text_encoder, byt5]
return text_encoders
......@@ -230,11 +230,10 @@ class HunyuanVideo15Runner(DefaultRunner):
if siglip_offload:
siglip_device = torch.device("cpu")
else:
siglip_device = torch.device(self.run_device)
siglip_device = torch.device(AI_DEVICE)
image_encoder = SiglipVisionEncoder(
config=self.config,
device=siglip_device,
run_device=self.run_device,
checkpoint_path=self.config["model_path"],
cpu_offload=siglip_offload,
)
......@@ -246,7 +245,7 @@ class HunyuanVideo15Runner(DefaultRunner):
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device(self.run_device)
vae_device = torch.device(AI_DEVICE)
vae_config = {
"checkpoint_path": self.config["model_path"],
......@@ -265,7 +264,7 @@ class HunyuanVideo15Runner(DefaultRunner):
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device(self.run_device)
vae_device = torch.device(AI_DEVICE)
vae_config = {
"checkpoint_path": self.config["model_path"],
......@@ -275,7 +274,7 @@ class HunyuanVideo15Runner(DefaultRunner):
}
if self.config.get("use_tae", False):
tae_path = self.config["tae_path"]
vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(self.run_device)
vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to(AI_DEVICE)
else:
vae_decoder = self.vae_cls(**vae_config)
return vae_decoder
......@@ -350,7 +349,7 @@ class HunyuanVideo15Runner(DefaultRunner):
self.model_sr.scheduler.step_post()
del self.inputs_sr
torch_ext_module = getattr(torch, self.run_device)
torch_ext_module = getattr(torch, AI_DEVICE)
torch_ext_module.empty_cache()
self.config_sr["is_sr_running"] = False
......@@ -369,10 +368,10 @@ class HunyuanVideo15Runner(DefaultRunner):
text_encoder_output = self.run_text_encoder(self.input_info)
# vision_states is all zero, because we don't have any image input
siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(self.run_device)
siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(self.run_device))
siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).to(AI_DEVICE)
siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device(AI_DEVICE))
torch_ext_module = getattr(torch, self.run_device)
torch_ext_module = getattr(torch, AI_DEVICE)
torch_ext_module.empty_cache()
gc.collect()
return {
......@@ -400,7 +399,7 @@ class HunyuanVideo15Runner(DefaultRunner):
siglip_output, siglip_mask = self.run_image_encoder(img_ori) if self.config.get("use_image_encoder", True) else None
cond_latents = self.run_vae_encoder(img_ori)
text_encoder_output = self.run_text_encoder(self.input_info)
torch_ext_module = getattr(torch, self.run_device)
torch_ext_module = getattr(torch, AI_DEVICE)
torch_ext_module.empty_cache()
gc.collect()
return {
......@@ -427,9 +426,9 @@ class HunyuanVideo15Runner(DefaultRunner):
target_height = self.target_height
input_image_np = self.resize_and_center_crop(first_frame, target_width=target_width, target_height=target_height)
vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(self.run_device), dtype=torch.bfloat16)
vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device(AI_DEVICE), dtype=torch.bfloat16)
image_encoder_output = self.image_encoder.infer(vision_states)
image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(self.run_device))
image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device(AI_DEVICE))
return image_encoder_output, image_encoder_mask
def resize_and_center_crop(self, image, target_width, target_height):
......@@ -480,6 +479,6 @@ class HunyuanVideo15Runner(DefaultRunner):
]
)
ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(self.run_device)
ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).to(AI_DEVICE)
cond_latents = self.vae_encoder.encode(ref_images_pixel_values.to(GET_DTYPE()))
return cond_latents
......@@ -15,6 +15,9 @@ from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
def calculate_dimensions(target_area, ratio):
......@@ -85,9 +88,7 @@ class QwenImageRunner(DefaultRunner):
def _run_input_encoder_local_t2i(self):
prompt = self.input_info.prompt
text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
if hasattr(torch, self.run_device):
torch_module = getattr(torch, self.run_device)
torch_module.empty_cache()
torch_device_module.empty_cache()
gc.collect()
return {
"text_encoder_output": text_encoder_output,
......@@ -102,7 +103,7 @@ class QwenImageRunner(DefaultRunner):
if GET_RECORDER_MODE():
width, height = img_ori.size
monitor_cli.lightx2v_input_image_len.observe(width * height)
img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.run_device)
img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
self.input_info.original_size.append(img_ori.size)
return img, img_ori
......@@ -121,9 +122,7 @@ class QwenImageRunner(DefaultRunner):
for vae_image in text_encoder_output["image_info"]["vae_image_list"]:
image_encoder_output = self.run_vae_encoder(image=vae_image)
image_encoder_output_list.append(image_encoder_output)
if hasattr(torch, self.run_device):
torch_module = getattr(torch, self.run_device)
torch_module.empty_cache()
torch_device_module.empty_cache()
gc.collect()
return {
"text_encoder_output": text_encoder_output,
......@@ -238,9 +237,7 @@ class QwenImageRunner(DefaultRunner):
images = self.vae.decode(latents, self.input_info)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_decoder
if hasattr(torch, self.run_device):
torch_module = getattr(torch, self.run_device)
torch_module.empty_cache()
torch_device_module.empty_cache()
gc.collect()
return images
......@@ -259,9 +256,7 @@ class QwenImageRunner(DefaultRunner):
image.save(f"{input_info.save_result_path}")
del latents, generator
if hasattr(torch, self.run_device):
torch_module = getattr(torch, self.run_device)
torch_module.empty_cache()
torch_device_module.empty_cache()
gc.collect()
# Return (images, audio) - audio is None for default runner
......
......@@ -33,6 +33,7 @@ from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import find_torch_model_path, load_weights, vae_to_comfyui_image_inplace
from lightx2v_platform.base.global_var import AI_DEVICE
warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io")
......@@ -450,7 +451,7 @@ class WanAudioRunner(WanRunner): # type:ignore
ref_img = img_path
else:
ref_img = load_image(img_path)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(self.run_device)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
ref_img, h, w = resize_image(
ref_img,
......@@ -538,15 +539,14 @@ class WanAudioRunner(WanRunner): # type:ignore
def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
"""Prepare previous latents for conditioning"""
device = self.run_device
dtype = GET_DTYPE()
tgt_h, tgt_w = self.input_info.target_shape[0], self.input_info.target_shape[1]
prev_frames = torch.zeros((1, 3, self.config["target_video_length"], tgt_h, tgt_w), device=device)
prev_frames = torch.zeros((1, 3, self.config["target_video_length"], tgt_h, tgt_w), device=AI_DEVICE)
if prev_video is not None:
# Extract and process last frames
last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
last_frames = prev_video[:, :, -prev_frame_length:].clone().to(AI_DEVICE)
if self.config["model_cls"] != "wan2.2_audio":
last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
prev_frames[:, :, :prev_frame_length] = last_frames
......@@ -574,7 +574,7 @@ class WanAudioRunner(WanRunner): # type:ignore
prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
frames_n = (nframe - 1) * 4 + 1
prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
prev_mask = torch.ones((1, frames_n, height, width), device=AI_DEVICE, dtype=dtype)
prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
prev_mask[:, prev_frame_len:] = 0
prev_mask = self._wan_mask_rearrange(prev_mask)
......@@ -835,7 +835,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def load_audio_encoder(self):
audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"))
audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, run_device=self.config.get("run_device", "cuda"))
model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload)
return model
def load_audio_adapter(self):
......@@ -843,7 +843,7 @@ class WanAudioRunner(WanRunner): # type:ignore
if audio_adapter_offload:
device = torch.device("cpu")
else:
device = torch.device(self.run_device)
device = torch.device(AI_DEVICE)
audio_adapter = AudioAdapter(
attention_head_dim=self.config["dim"] // self.config["num_heads"],
num_attention_heads=self.config["num_heads"],
......@@ -856,7 +856,6 @@ class WanAudioRunner(WanRunner): # type:ignore
quantized=self.config.get("adapter_quantized", False),
quant_scheme=self.config.get("adapter_quant_scheme", None),
cpu_offload=audio_adapter_offload,
run_device=self.run_device,
)
audio_adapter.to(device)
......@@ -892,11 +891,10 @@ class Wan22AudioRunner(WanAudioRunner):
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device(self.run_device)
vae_device = torch.device(AI_DEVICE)
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
"device": vae_device,
"run_device": self.run_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
......@@ -909,11 +907,10 @@ class Wan22AudioRunner(WanAudioRunner):
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device(self.run_device)
vae_device = torch.device(AI_DEVICE)
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
"device": vae_device,
"run_device": self.run_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
......
......@@ -29,6 +29,7 @@ from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v_platform.base.global_var import AI_DEVICE
@RUNNER_REGISTER("wan2.1")
......@@ -65,7 +66,7 @@ class WanRunner(DefaultRunner):
if clip_offload:
clip_device = torch.device("cpu")
else:
clip_device = torch.device(self.run_device)
clip_device = torch.device(AI_DEVICE)
# quant_config
clip_quantized = self.config.get("clip_quantized", False)
if clip_quantized:
......@@ -84,7 +85,6 @@ class WanRunner(DefaultRunner):
image_encoder = CLIPModel(
dtype=torch.float16,
device=clip_device,
run_device=self.run_device,
checkpoint_path=clip_original_ckpt,
clip_quantized=clip_quantized,
clip_quantized_ckpt=clip_quantized_ckpt,
......@@ -102,7 +102,7 @@ class WanRunner(DefaultRunner):
if t5_offload:
t5_device = torch.device("cpu")
else:
t5_device = torch.device(self.run_device)
t5_device = torch.device(AI_DEVICE)
tokenizer_path = os.path.join(self.config["model_path"], "google/umt5-xxl")
# quant_config
t5_quantized = self.config.get("t5_quantized", False)
......@@ -123,7 +123,6 @@ class WanRunner(DefaultRunner):
text_len=self.config["text_len"],
dtype=torch.bfloat16,
device=t5_device,
run_device=self.run_device,
checkpoint_path=t5_original_ckpt,
tokenizer_path=tokenizer_path,
shard_fn=None,
......@@ -142,12 +141,11 @@ class WanRunner(DefaultRunner):
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device(self.run_device)
vae_device = torch.device(AI_DEVICE)
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
"device": vae_device,
"run_device": self.run_device,
"parallel": self.config["parallel"],
"use_tiling": self.config.get("use_tiling_vae", False),
"cpu_offload": vae_offload,
......@@ -171,7 +169,6 @@ class WanRunner(DefaultRunner):
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
"device": vae_device,
"run_device": self.run_device,
"parallel": self.config["parallel"],
"use_tiling": self.config.get("use_tiling_vae", False),
"cpu_offload": vae_offload,
......@@ -321,7 +318,7 @@ class WanRunner(DefaultRunner):
self.config["target_video_length"],
lat_h,
lat_w,
device=torch.device(self.run_device),
device=torch.device(AI_DEVICE),
)
if last_frame is not None:
msk[:, 1:-1] = 0
......@@ -343,7 +340,7 @@ class WanRunner(DefaultRunner):
torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
],
dim=1,
).to(self.run_device)
).to(AI_DEVICE)
else:
vae_input = torch.concat(
[
......@@ -351,7 +348,7 @@ class WanRunner(DefaultRunner):
torch.zeros(3, self.config["target_video_length"] - 1, h, w),
],
dim=1,
).to(self.run_device)
).to(AI_DEVICE)
vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()))
......@@ -534,7 +531,7 @@ class Wan22DenseRunner(WanRunner):
assert img.width == ow and img.height == oh
# to tensor
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.run_device).unsqueeze(1)
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(AI_DEVICE).unsqueeze(1)
vae_encoder_out = self.get_vae_encoder_output(img)
latent_w, latent_h = ow // self.config["vae_stride"][2], oh // self.config["vae_stride"][1]
latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
......
......@@ -271,8 +271,7 @@ def get_1d_rotary_pos_embed(
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
run_device = kwds.get("run_device", "cuda")
freqs = torch.outer(pos * interpolation_factor, freqs).to(run_device) # [S, D/2]
freqs = torch.outer(pos * interpolation_factor, freqs).to(AI_DEVICE) # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
......
......@@ -11,7 +11,6 @@ from .posemb_layers import get_nd_rotary_pos_embed
class HunyuanVideo15Scheduler(BaseScheduler):
def __init__(self, config):
super().__init__(config)
self.run_device = torch.device(self.config.get("run_device", "cuda"))
self.reverse = True
self.num_train_timesteps = 1000
self.sample_shift = self.config["sample_shift"]
......@@ -25,13 +24,13 @@ class HunyuanVideo15Scheduler(BaseScheduler):
def prepare(self, seed, latent_shape, image_encoder_output=None):
self.prepare_latents(seed, latent_shape, dtype=torch.bfloat16)
self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift)
self.multitask_mask = self.get_task_mask(self.config["task"], latent_shape[-3])
self.cond_latents_concat, self.mask_concat = self._prepare_cond_latents_and_mask(self.config["task"], image_encoder_output["cond_latents"], self.latents, self.multitask_mask, self.reorg_token)
self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))
def prepare_latents(self, seed, latent_shape, dtype=torch.bfloat16):
self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed)
self.latents = torch.randn(
1,
latent_shape[0],
......@@ -39,7 +38,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
latent_shape[2],
latent_shape[3],
dtype=dtype,
device=self.run_device,
device=AI_DEVICE,
generator=self.generator,
)
......@@ -127,7 +126,7 @@ class HunyuanVideo15Scheduler(BaseScheduler):
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=self.run_device)
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, rope_sizes, theta=self.config["rope_theta"], use_real=True, theta_rescale_factor=1, device=AI_DEVICE)
cos_half = freqs_cos[:, ::2].contiguous()
sin_half = freqs_sin[:, ::2].contiguous()
cos_sin = torch.cat([cos_half, sin_half], dim=-1)
......@@ -149,9 +148,8 @@ class HunyuanVideo15SRScheduler(HunyuanVideo15Scheduler):
def prepare(self, seed, latent_shape, lq_latents, upsampler, image_encoder_output=None):
dtype = lq_latents.dtype
device = lq_latents.device
self.prepare_latents(seed, latent_shape, lq_latents, dtype=dtype)
self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift)
self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))
tgt_shape = latent_shape[-2:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment