import os
import sys
import torch
from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferTaylorCaching, HunyuanTransformerInferTeaCaching
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.utils.envs import *
from loguru import logger


class HunyuanModel:
    """Bundle the Hunyuan weights and inference stages behind one object.

    The model is split into three stages (pre, transformer, post). Each stage
    has a weight container (``*_weight_class``) and an inference object
    (``*_infer_class``); ``infer`` runs the three stages in order and stores
    the result on the attached scheduler.

    Config keys read by this class: ``feature_caching``, ``parallel_attn_type``,
    ``cpu_offload``, ``mm_config`` and (for quantized weights) ``naive_quant_path``.
    """

    # Weight container classes; kept as class attributes so they can be overridden.
    pre_weight_class = HunyuanPreWeights
    post_weight_class = HunyuanPostWeights
    transformer_weight_class = HunyuanTransformerWeights

    def __init__(self, model_path, config, device, args):
        """Load weights, build the inference stages and apply optional wrappers.

        Args:
            model_path: Root directory containing the checkpoint sub-directories.
            config: Mapping-like configuration (also accessed by attribute for
                ``naive_quant_path``).
            device: Passed to ``torch.load`` as ``map_location``.
            args: Object exposing ``task``; ``"t2v"`` selects the text-to-video
                checkpoint, anything else the image-to-video one.

        Raises:
            ValueError: If ``parallel_attn_type`` is truthy but neither
                ``"ulysses"`` nor ``"ring"``.
            NotImplementedError: If ``feature_caching`` is not a supported value.

        Note:
            When ``GET_RUNNING_FLAG()`` returns ``"save_naive_quant"`` the
            constructor saves the loaded weights and terminates the process
            with ``sys.exit(0)``.
        """
        self.model_path = model_path
        self.config = config
        self.device = device
        self.args = args

        self._init_infer_class()
        self._init_weights()

        # Export-only mode: dump the weights and stop before building inference.
        if GET_RUNNING_FLAG() == "save_naive_quant":
            assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
            self.save_weights(self.config.naive_quant_path)
            sys.exit(0)

        self._init_infer()

        attn_type = config["parallel_attn_type"]
        if attn_type:
            if attn_type == "ulysses":
                ulysses_dist_wrap.parallelize_hunyuan(self)
            elif attn_type == "ring":
                ring_dist_wrap.parallelize_hunyuan(self)
            else:
                # ValueError is still an Exception, so existing handlers keep working.
                raise ValueError(f"Unsupported parallel_attn_type: {attn_type}")

        if self.config["cpu_offload"]:
            self.to_cpu()

    def _init_infer_class(self):
        """Select the inference classes, choosing the transformer one by ``feature_caching``.

        Raises:
            NotImplementedError: For any ``feature_caching`` value other than
                ``"NoCaching"``, ``"TaylorSeer"`` or ``"Tea"``.
        """
        self.pre_infer_class = HunyuanPreInfer
        self.post_infer_class = HunyuanPostInfer

        caching = self.config["feature_caching"]
        if caching == "NoCaching":
            self.transformer_infer_class = HunyuanTransformerInfer
        elif caching == "TaylorSeer":
            self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
        elif caching == "Tea":
            self.transformer_infer_class = HunyuanTransformerInferTeaCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _load_ckpt(self):
        """Load the original checkpoint for the current task.

        Returns:
            The ``"module"`` entry of the checkpoint loaded from
            ``<model_path>/hunyuan-video-{t2v,i2v}-720p/transformers/``.
        """
        # Only the sub-directory depends on the task; the file name is shared.
        if self.args.task == "t2v":
            ckpt_dir = "hunyuan-video-t2v-720p/transformers"
        else:
            ckpt_dir = "hunyuan-video-i2v-720p/transformers"
        ckpt_path = os.path.join(self.model_path, ckpt_dir + "/mp_rank_00_model_states.pt")
        weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
        return weight_dict

    def _load_ckpt_quant_model(self):
        """Load ``quant_weights.pth`` from ``config.naive_quant_path``.

        Returns:
            The object stored in that file (the dict written by ``save_weights``).
        """
        assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None"
        logger.info(f"Loading quant model from {self.config.naive_quant_path}")
        quant_weights_path = os.path.join(self.config.naive_quant_path, "quant_weights.pth")
        weight_dict = torch.load(quant_weights_path, map_location=self.device, weights_only=True)
        return weight_dict

    def _init_weights(self):
        """Create the three weight containers and populate them from a checkpoint.

        The original checkpoint is used when exporting quantized weights, when
        ``mm_config["weight_auto_quant"]`` is set, or when ``mm_config["mm_type"]``
        is ``"Default"`` (also the fallback when unset). Otherwise the previously
        saved quantized weights are loaded.
        """
        mm_config = self.config["mm_config"]
        use_original_ckpt = (
            GET_RUNNING_FLAG() == "save_naive_quant"
            or mm_config.get("weight_auto_quant", False)
            or mm_config.get("mm_type", "Default") == "Default"
        )
        if use_original_ckpt:
            weight_dict = self._load_ckpt()
        else:
            weight_dict = self._load_ckpt_quant_model()

        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)

        # load weights
        self.pre_weight.load(weight_dict)
        self.post_weight.load(weight_dict)
        self.transformer_weights.load(weight_dict)

    def _init_infer(self):
        """Instantiate the pre, post and transformer inference objects."""
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        self.transformer_infer = self.transformer_infer_class(self.config)

    def save_weights(self, save_path):
        """Merge the three state dicts and write them to ``<save_path>/quant_weights.pth``.

        Args:
            save_path: Target directory; created if it does not exist.
        """
        # exist_ok avoids the check-then-create race of a separate os.path.exists test.
        os.makedirs(save_path, exist_ok=True)

        pre_state_dict = self.pre_weight.state_dict()
        logger.info(pre_state_dict.keys())

        post_state_dict = self.post_weight.state_dict()
        logger.info(post_state_dict.keys())

        transformer_state_dict = self.transformer_weights.state_dict()
        logger.info(transformer_state_dict.keys())

        # Later updates win on key collisions: pre < post < transformer.
        save_dict = {}
        save_dict.update(pre_state_dict)
        save_dict.update(post_state_dict)
        save_dict.update(transformer_state_dict)

        weights_file = os.path.join(save_path, "quant_weights.pth")
        torch.save(save_dict, weights_file)
        logger.info(f"Save weights to {weights_file}")

    def set_scheduler(self, scheduler):
        """Attach ``scheduler`` to the model and to each inference stage."""
        self.scheduler = scheduler
        self.pre_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)

    def to_cpu(self):
        """Call ``to_cpu()`` on all three weight containers."""
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        """Call ``to_cuda()`` on all three weight containers."""
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

    @torch.no_grad()
    def infer(self, inputs):
        """Run pre -> transformer -> post and store the result on the scheduler.

        The output of the post stage is written to ``self.scheduler.noise_pred``;
        nothing is returned. ``set_scheduler`` must have been called first.

        Args:
            inputs: Passed as-is to the pre stage; each later stage receives the
                unpacked output of the stage before it.
        """
        offload = self.config["cpu_offload"]
        try:
            if offload:
                # Only pre/post weights are moved here; the transformer weights are
                # not touched by this method (presumably handled inside
                # transformer_infer -- TODO confirm).
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()

            inputs = self.pre_infer.infer(self.pre_weight, inputs)
            inputs = self.transformer_infer.infer(self.transformer_weights, *inputs)
            self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, *inputs)
        finally:
            # Move the weights back even if a stage raised, so a failed step does
            # not leave them resident on the GPU.
            if offload:
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()

        if self.config["feature_caching"] == "Tea":
            # Per-call counter on the scheduler, wrapped back to 0 at num_steps.
            self.scheduler.cnt += 1
            if self.scheduler.cnt == self.scheduler.num_steps:
                self.scheduler.cnt = 0