Commit daf4c74e authored by helloyongyang, committed by Yang Yong(雍洋)

first commit

parent 6c79160f
import torch
from qtorch.quant import float_quantize


class BaseQuantizer(object):
    def __init__(self, bit, symmetric, granularity, **kwargs):
        self.bit = bit
        self.sym = symmetric
        self.granularity = granularity
        self.kwargs = kwargs
        if self.granularity == 'per_group':
            self.group_size = self.kwargs['group_size']
        self.calib_algo = self.kwargs.get('calib_algo', 'minmax')

    def get_tensor_range(self, tensor):
        if self.calib_algo == 'minmax':
            return self.get_minmax_range(tensor)
        elif self.calib_algo == 'mse':
            return self.get_mse_range(tensor)
        else:
            raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}')

    def get_minmax_range(self, tensor):
        if self.granularity == 'per_tensor':
            max_val = torch.max(tensor)
            min_val = torch.min(tensor)
        else:
            max_val = tensor.amax(dim=-1, keepdim=True)
            min_val = tensor.amin(dim=-1, keepdim=True)
        return (min_val, max_val)

    def get_mse_range(self, tensor):
        raise NotImplementedError
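
    # Affine quantization parameters computed below:
    #   symmetric:  scale = max(|x|) / qmax, zero_point = 0
    #   asymmetric: scale = (max - min) / (qmax - qmin),
    #               zero_point = clamp(qmin - round(min / scale), qmin, qmax)
    # e.g. int4 asymmetric (qmin=0, qmax=15) with min=-1.0, max=2.0 gives
    #   scale = 3.0 / 15 = 0.2 and zero_point = clamp(0 - round(-1.0 / 0.2)) = 5.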
    def get_qparams(self, tensor_range, device):
        min_val, max_val = tensor_range[0], tensor_range[1]
        qmin = self.qmin.to(device)
        qmax = self.qmax.to(device)
        if self.sym:
            abs_max = torch.max(max_val.abs(), min_val.abs())
            abs_max = abs_max.clamp(min=1e-5)
            scales = abs_max / qmax
            zeros = torch.tensor(0.0)
        else:
            scales = (max_val - min_val).clamp(min=1e-5) / (qmax - qmin)
            zeros = (qmin - torch.round(min_val / scales)).clamp(qmin, qmax)
        return scales, zeros, qmax, qmin

    def reshape_tensor(self, tensor, allow_padding=False):
        if self.granularity == 'per_group':
            t = tensor.reshape(-1, self.group_size)
        else:
            t = tensor
        return t

    def restore_tensor(self, tensor, shape):
        if tensor.shape == shape:
            t = tensor
        else:
            t = tensor.reshape(shape)
        return t

    def get_tensor_qparams(self, tensor):
        tensor = self.reshape_tensor(tensor)
        tensor_range = self.get_tensor_range(tensor)
        scales, zeros, qmax, qmin = self.get_qparams(tensor_range, tensor.device)
        return tensor, scales, zeros, qmax, qmin

    def fake_quant_tensor(self, tensor):
        org_shape = tensor.shape
        org_dtype = tensor.dtype
        tensor, scales, zeros, qmax, qmin = self.get_tensor_qparams(tensor)
        tensor = self.quant_dequant(tensor, scales, zeros, qmax, qmin)
        tensor = self.restore_tensor(tensor, org_shape).to(org_dtype)
        return tensor

    def real_quant_tensor(self, tensor):
        org_shape = tensor.shape
        tensor, scales, zeros, qmax, qmin = self.get_tensor_qparams(tensor)
        tensor = self.quant(tensor, scales, zeros, qmax, qmin)
        tensor = self.restore_tensor(tensor, org_shape)
        if self.sym:
            zeros = None
        return tensor, scales, zeros
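

# Integer quantizer: signed range [-2^(bit-1), 2^(bit-1) - 1] for symmetric
# quantization and unsigned range [0, 2^bit - 1] for asymmetric, unless an
# explicit int_range is supplied via kwargs.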
class IntegerQuantizer(BaseQuantizer):
    def __init__(self, bit, symmetric, granularity, **kwargs):
        super().__init__(bit, symmetric, granularity, **kwargs)
        if 'int_range' in self.kwargs:
            self.qmin = self.kwargs['int_range'][0]
            self.qmax = self.kwargs['int_range'][1]
        else:
            if self.sym:
                self.qmin = -(2 ** (self.bit - 1))
                self.qmax = 2 ** (self.bit - 1) - 1
            else:
                self.qmin = 0.0
                self.qmax = 2**self.bit - 1
        self.qmin = torch.tensor(self.qmin)
        self.qmax = torch.tensor(self.qmax)
        self.dst_nbins = 2**bit

    def quant(self, tensor, scales, zeros, qmax, qmin):
        tensor = torch.clamp(torch.round(tensor / scales) + zeros, qmin, qmax)
        return tensor

    def dequant(self, tensor, scales, zeros):
        tensor = (tensor - zeros) * scales
        return tensor

    def quant_dequant(self, tensor, scales, zeros, qmax, qmin):
        tensor = self.quant(tensor, scales, zeros, qmax, qmin)
        tensor = self.dequant(tensor, scales, zeros)
        return tensor
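

# FP8 quantizer: supports the two standard 8-bit float formats, e4m3
# (4 exponent / 3 mantissa bits) and e5m2 (5 exponent / 2 mantissa bits).
# Rounding to the FP8 grid is delegated to qtorch's float_quantize; only
# symmetric scaling is supported.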
class FloatQuantizer(BaseQuantizer):
    def __init__(self, bit, symmetric, granularity, **kwargs):
        super().__init__(bit, symmetric, granularity, **kwargs)
        assert self.bit in ['e4m3', 'e5m2'], f'Unsupported bit configuration: {self.bit}'
        assert self.sym, 'FloatQuantizer only supports symmetric quantization'
        if self.bit == 'e4m3':
            self.e_bits = 4
            self.m_bits = 3
            self.fp_dtype = torch.float8_e4m3fn
        elif self.bit == 'e5m2':
            self.e_bits = 5
            self.m_bits = 2
            self.fp_dtype = torch.float8_e5m2
        else:
            raise ValueError(f'Unsupported bit configuration: {self.bit}')
        finfo = torch.finfo(self.fp_dtype)
        self.qmin, self.qmax = finfo.min, finfo.max
        self.qmax = torch.tensor(self.qmax)
        self.qmin = torch.tensor(self.qmin)

    def quant(self, tensor, scales, zeros, qmax, qmin):
        scaled_tensor = tensor / scales + zeros
        scaled_tensor = torch.clip(
            scaled_tensor, self.qmin.to(scaled_tensor.device), self.qmax.to(scaled_tensor.device)
        )
        org_dtype = scaled_tensor.dtype
        q_tensor = float_quantize(
            scaled_tensor.float(), self.e_bits, self.m_bits, rounding='nearest'
        )
        q_tensor = q_tensor.to(org_dtype)  # .to() is not in-place; keep the cast
        return q_tensor

    def dequant(self, tensor, scales, zeros):
        tensor = (tensor - zeros) * scales
        return tensor

    def quant_dequant(self, tensor, scales, zeros, qmax, qmin):
        tensor = self.quant(tensor, scales, zeros, qmax, qmin)
        tensor = self.dequant(tensor, scales, zeros)
        return tensor


if __name__ == '__main__':
    weight = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda()
    quantizer = IntegerQuantizer(4, False, 'per_group', group_size=128)
    q_weight = quantizer.fake_quant_tensor(weight)
    print(weight)
    print(q_weight)
    print(f"cosine = {torch.cosine_similarity(weight.view(1, -1).to(torch.float64), q_weight.view(1, -1).to(torch.float64))}")
    realq_weight, scales, zeros = quantizer.real_quant_tensor(weight)
    print(f"realq_weight = {realq_weight}, {realq_weight.shape}")
    print(f"scales = {scales}, {scales.shape}")
    print(f"zeros = {zeros}, {zeros.shape}")

    weight = torch.randn(8192, 4096, dtype=torch.bfloat16).cuda()
    quantizer = FloatQuantizer('e4m3', True, 'per_channel')
    q_weight = quantizer.fake_quant_tensor(weight)
    print(weight)
    print(q_weight)
    print(f"cosine = {torch.cosine_similarity(weight.view(1, -1).to(torch.float64), q_weight.view(1, -1).to(torch.float64))}")
    realq_weight, scales, zeros = quantizer.real_quant_tensor(weight)
    print(f"realq_weight = {realq_weight}, {realq_weight.shape}")
    print(f"scales = {scales}, {scales.shape}")
    print(f"zeros = {zeros}")


class Register(dict):
    def __init__(self, *args, **kwargs):
        super(Register, self).__init__(*args, **kwargs)
        self._dict = {}

    def __call__(self, target_or_name):
        if callable(target_or_name):
            return self.register(target_or_name)
        else:
            return lambda x: self.register(x, key=target_or_name)

    def register(self, target, key=None):
        if not callable(target):
            raise Exception(f'Error: {target} must be callable!')
        if key is None:
            key = target.__name__
        if key in self._dict:
            raise Exception(f'{key} already exists.')
        self[key] = target
        return target

    def __setitem__(self, key, value):
        self._dict[key] = value

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def __str__(self):
        return str(self._dict)

    def keys(self):
        return self._dict.keys()

    def values(self):
        return self._dict.values()

    def items(self):
        return self._dict.items()


MM_WEIGHT_REGISTER = Register()
RMS_WEIGHT_REGISTER = Register()
LN_WEIGHT_REGISTER = Register()
CONV3D_WEIGHT_REGISTER = Register()
CONV2D_WEIGHT_REGISTER = Register()
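
# Register works both as a plain decorator and as a keyed decorator.
# Illustrative sketch only (MMWeightInt8 is a hypothetical class; the real
# implementations are registered in lightx2v.common.ops):
#
#   @MM_WEIGHT_REGISTER('W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm')
#   class MMWeightInt8:
#       ...
#
#   mm_cls = MM_WEIGHT_REGISTER['W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm']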


import os
import random

from einops import rearrange
import torch
import torchvision
import numpy as np
import imageio


def seed_all(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
    """Save a batch of videos as a single grid video.

    Copied from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61

    Args:
        videos (torch.Tensor): video tensor predicted by the model, shaped (b, c, t, h, w)
        path (str): path to save the video
        rescale (bool, optional): rescale the video tensor from [-1, 1] to [0, 1]. Defaults to False.
        n_rows (int, optional): number of videos per grid row. Defaults to 1.
        fps (int, optional): frame rate of the saved video. Defaults to 24.
    """
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        x = torch.clamp(x, 0, 1)
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)


def cache_video(
    tensor,
    save_file,
    fps=30,
    suffix=".mp4",
    nrow=8,
    normalize=True,
    value_range=(-1, 1),
    retry=5,
):
    cache_file = save_file
    # save to cache
    error = None
    for _ in range(retry):
        try:
            # preprocess
            tensor = tensor.clamp(min(value_range), max(value_range))
            tensor = torch.stack(
                [
                    torchvision.utils.make_grid(
                        u, nrow=nrow, normalize=normalize, value_range=value_range
                    )
                    for u in tensor.unbind(2)
                ],
                dim=1,
            ).permute(1, 2, 3, 0)
            tensor = (tensor * 255).type(torch.uint8).cpu()
            # write video
            writer = imageio.get_writer(cache_file, fps=fps, codec="libx264", quality=8)
            for frame in tensor.numpy():
                writer.append_data(frame)
            writer.close()
            return cache_file
        except Exception as e:
            error = e
            continue
    else:
        print(f"cache_video failed, error: {error}", flush=True)
        return None
import argparse
import torch
import torch.distributed as dist
import os
import time
import gc
import json
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image
from lightx2v.text2v.models.text_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel
from lightx2v.text2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.text2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerFeatureCaching
from lightx2v.text2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.text2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerFeatureCaching
from lightx2v.text2v.models.networks.hunyuan.model import HunyuanModel
from lightx2v.text2v.models.networks.wan.model import WanModel
from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
from lightx2v.text2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
from lightx2v.common.ops import *
from lightx2v.image2v.models.wan.model import CLIPModel
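
# End-to-end flow (driven by __main__ below): load the diffusion model, text
# encoders and VAE (plus the CLIP image encoder for i2v), run the text/image
# encoders, build the scheduler, run the denoising loop, then decode the
# latents with the VAE and save the video.
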
def load_models(args, model_config):
    if model_config['parallel_attn']:
        cur_rank = dist.get_rank()  # rank of the current process
        torch.cuda.set_device(cur_rank)  # bind this process to its CUDA device
    image_encoder = None
    if args.cpu_offload:
        init_device = torch.device("cpu")
    else:
        init_device = torch.device("cuda")
    if args.model_cls == "hunyuan":
        text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device)
        text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
        text_encoders = [text_encoder_1, text_encoder_2]
        model = HunyuanModel(args.model_path, model_config)
        vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device)
    elif args.model_cls == "wan2.1":
        text_encoder = T5EncoderModel(
            text_len=model_config["text_len"],
            dtype=torch.bfloat16,
            device=torch.device("cuda"),
            checkpoint_path=os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
            tokenizer_path=os.path.join(args.model_path, "google/umt5-xxl"),
            shard_fn=None,
        )
        text_encoders = [text_encoder]
        model = WanModel(args.model_path, model_config)
        vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=torch.device("cuda"))
        if args.task == 'i2v':
            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=torch.device("cuda"),
                checkpoint_path=os.path.join(args.model_path,
                                             "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
                tokenizer_path=os.path.join(args.model_path, "xlm-roberta-large"))
    else:
        raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
    return model, text_encoders, vae_model, image_encoder
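
# Latent target shape: the temporal axis is compressed 4x ((T - 1) // 4 + 1),
# the spatial axes by the VAE stride (8x for both model families), and the
# latent space has 16 channels. For wan2.1 i2v the latent height/width come
# from run_image_encoder via args.lat_h / args.lat_w.
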
def set_target_shape(args):
    if args.model_cls == 'hunyuan':
        vae_scale_factor = 2 ** (4 - 1)
        args.target_shape = (
            1,
            16,
            (args.target_video_length - 1) // 4 + 1,
            int(args.target_height) // vae_scale_factor,
            int(args.target_width) // vae_scale_factor,
        )
    elif args.model_cls == 'wan2.1':
        if args.task == 'i2v':
            args.target_shape = (
                16,
                21,
                args.lat_h,
                args.lat_w
            )
        elif args.task == 't2v':
            args.target_shape = (
                16,
                (args.target_video_length - 1) // 4 + 1,
                int(args.target_height) // args.vae_stride[1],
                int(args.target_width) // args.vae_stride[2],
            )


def run_image_encoder(args, image_encoder, vae_model):
    if args.model_cls == "hunyuan":
        return None
    elif args.model_cls == 'wan2.1':
        img = Image.open(args.image_path).convert("RGB")
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = image_encoder.visual([img[:, None, :, :]]).squeeze(0).to(torch.bfloat16)
        h, w = img.shape[1:]
        aspect_ratio = h / w
        max_area = args.target_height * args.target_width
        lat_h = round(
            np.sqrt(max_area * aspect_ratio) // args.vae_stride[1] //
            args.patch_size[1] * args.patch_size[1])
        lat_w = round(
            np.sqrt(max_area / aspect_ratio) // args.vae_stride[2] //
            args.patch_size[2] * args.patch_size[2])
        h = lat_h * args.vae_stride[1]
        w = lat_w * args.vae_stride[2]
        args.lat_h = lat_h
        args.lat_w = lat_w
        msk = torch.ones(1, 81, lat_h, lat_w, device=torch.device('cuda'))
        msk[:, 1:] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
        ], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        vae_encode_out = vae_model.encode([
            torch.concat([
                torch.nn.functional.interpolate(
                    img[None].cpu(), size=(h, w), mode='bicubic').transpose(
                        0, 1),
                torch.zeros(3, 80, h, w)
            ], dim=1).cuda()
        ])[0]
        vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
        return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
    else:
        raise NotImplementedError(f"Unsupported model class: {args.model_cls}")


def run_text_encoder(args, text, text_encoders, model_config):
    text_encoder_output = {}
    if args.model_cls == "hunyuan":
        for i, encoder in enumerate(text_encoders):
            text_state, attention_mask = encoder.infer(text, args)
            text_encoder_output[f"text_encoder_{i+1}_text_states"] = text_state.to(dtype=torch.bfloat16)
            text_encoder_output[f"text_encoder_{i+1}_attention_mask"] = attention_mask
    elif args.model_cls == "wan2.1":
        n_prompt = model_config.get("sample_neg_prompt", "")
        context = text_encoders[0].infer([text], args)
        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], args)
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null
    else:
        raise NotImplementedError(f"Unsupported model type: {args.model_cls}")
    return text_encoder_output


def init_scheduler(args):
    if args.model_cls == "hunyuan":
        if args.feature_caching == "NoCaching":
            scheduler = HunyuanScheduler(args)
        elif args.feature_caching == "TaylorSeer":
            scheduler = HunyuanSchedulerFeatureCaching(args)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
    elif args.model_cls == "wan2.1":
        if args.feature_caching == "NoCaching":
            scheduler = WanScheduler(args)
        elif args.feature_caching == "Tea":
            scheduler = WanSchedulerFeatureCaching(args)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
    else:
        raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
    return scheduler
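
# Denoising loop: every step runs scheduler.step_pre -> model.infer ->
# scheduler.step_post, with torch.cuda.synchronize() around each phase so the
# printed per-step timings are accurate.
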
def run_main_inference(args, model, text_encoder_output, image_encoder_output):
    for step_index in range(model.scheduler.infer_steps):
        torch.cuda.synchronize()
        time1 = time.time()
        model.scheduler.step_pre(step_index=step_index)
        torch.cuda.synchronize()
        time2 = time.time()
        model.infer(text_encoder_output, image_encoder_output, args)
        torch.cuda.synchronize()
        time3 = time.time()
        model.scheduler.step_post()
        torch.cuda.synchronize()
        time4 = time.time()
        print(f"step {step_index} infer time: {time3 - time2}")
        print(f"step {step_index} all time: {time4 - time1}")
        print("*" * 10)
    return model.scheduler.latents, model.scheduler.generator


def run_vae(vae_model, latents, generator, args):
    images = vae_model.decode(latents, generator=generator, args=args)
    return images


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan"], default="hunyuan")
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_path", type=str, default=None)
    parser.add_argument("--image_path", type=str, default=None)
    parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4")
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--infer_steps", type=int, required=True)
    parser.add_argument("--target_video_length", type=int, required=True)
    parser.add_argument("--target_width", type=int, required=True)
    parser.add_argument("--target_height", type=int, required=True)
    parser.add_argument("--attention_type", type=str, required=True)
    parser.add_argument("--sample_neg_prompt", type=str, default="")
    parser.add_argument("--sample_guide_scale", type=float, default=5.0)
    parser.add_argument("--sample_shift", type=float, default=5.0)
    parser.add_argument("--do_mm_calib", action="store_true")
    parser.add_argument("--cpu_offload", action="store_true")
    parser.add_argument("--feature_caching", choices=["NoCaching", "TaylorSeer", "Tea"], default="NoCaching")
    parser.add_argument("--mm_config", default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--parallel_attn", action="store_true")
    parser.add_argument("--max_area", action="store_true")
    parser.add_argument("--vae_stride", default=(4, 8, 8))
    parser.add_argument("--patch_size", default=(1, 2, 2))
    parser.add_argument("--teacache_thresh", type=float, default=0.26)
    parser.add_argument("--use_ret_steps", action="store_true", default=False)
    args = parser.parse_args()

    start_time = time.time()
    print(f"args: {args}")
    seed_all(args.seed)

    if args.parallel_attn:
        dist.init_process_group(backend="nccl")

    if args.mm_config:
        mm_config = json.loads(args.mm_config)
    else:
        mm_config = None

    model_config = {
        "task": args.task,
        "attention_type": args.attention_type,
        "sample_neg_prompt": args.sample_neg_prompt,
        "mm_config": mm_config,
        "do_mm_calib": args.do_mm_calib,
        "cpu_offload": args.cpu_offload,
        "feature_caching": args.feature_caching,
        "parallel_attn": args.parallel_attn
    }
    if args.config_path is not None:
        with open(args.config_path, "r") as f:
            config = json.load(f)
        model_config.update(config)
    print(f"model_config: {model_config}")

    model, text_encoders, vae_model, image_encoder = load_models(args, model_config)

    if args.task in ['i2v']:
        image_encoder_output = run_image_encoder(args, image_encoder, vae_model)
    else:
        image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}
    text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config)

    set_target_shape(args)
    scheduler = init_scheduler(args)
    model.set_scheduler(scheduler)

    gc.collect()
    torch.cuda.empty_cache()

    if args.cpu_offload:
        model.to_cuda()
    latents, generator = run_main_inference(args, model, text_encoder_output, image_encoder_output)
    if args.cpu_offload:
        model.to_cpu()
    gc.collect()
    torch.cuda.empty_cache()

    images = run_vae(vae_model, latents, generator, args)

    if args.model_cls == "wan2.1":
        cache_video(tensor=images, save_file=args.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
    else:
        save_videos_grid(images, args.save_video_path, fps=24)

    end_time = time.time()
    print(f"Total time: {end_time - start_time}")
#!/bin/bash
# model_path=/mnt/nvme1/yongyang/models/hy/ckpts # H800-13
model_path=/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts # H800-14
# model_path=/workspace/ckpts_link # H800-14
# export CUDA_VISIBLE_DEVICES=2
# python main.py \
# --model_cls hunyuan \
# --model_path $model_path \
# --prompt "A cat walks on the grass, realistic style." \
# --infer_steps 20 \
# --target_video_length 33 \
# --target_height 720 \
# --target_width 1280 \
# --attention_type flash_attn3 \
# --save_video_path ./output_lightx2v_int8.mp4 \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
# export CUDA_VISIBLE_DEVICES=0,1,2,3
# torchrun --nproc_per_node=4 main.py \
# --model_cls hunyuan \
# --model_path $model_path \
# --prompt "A cat walks on the grass, realistic style." \
# --infer_steps 20 \
# --target_video_length 33 \
# --target_height 720 \
# --target_width 1280 \
# --attention_type flash_attn2 \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
# --parallel_attn
export CUDA_VISIBLE_DEVICES=2
python main.py \
--model_cls hunyuan \
--model_path $model_path \
--prompt "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting." \
--infer_steps 50 \
--target_video_length 65 \
--target_height 480 \
--target_width 640 \
--attention_type flash_attn3 \
--cpu_offload \
--feature_caching TaylorSeer \
--save_video_path ./output_lightx2v_offload_TaylorSeer.mp4 \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
#!/bin/bash
export CUDA_VISIBLE_DEVICES=2
# model_path=/mnt/nvme1/yongyang/models/hy/ckpts # H800-13
model_path=/mnt/nvme0/yongyang/projects/wan/Wan2.1-I2V-14B-480P # H800-14
config_path=/mnt/nvme0/yongyang/projects/wan/Wan2.1-I2V-14B-480P/config.json
python main.py \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--infer_steps 40 \
--target_video_length 81 \
--target_width 832 \
--target_height 480 \
--attention_type flash_attn3 \
--seed 42 \
--sample_neg_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--config_path $config_path \
--save_video_path ./output_lightx2v_seed42_fp8_base.mp4 \
--sample_guide_scale 5 \
--sample_shift 5 \
--image_path ./i2v_input.JPG \
--mm_config '{"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
# --feature_caching Tea \
# --use_ret_steps \
#!/bin/bash
export CUDA_VISIBLE_DEVICES=2
# model_path=/mnt/nvme1/yongyang/models/hy/ckpts # H800-13
model_path=/mnt/nvme0/yongyang/projects/wan/Wan2.1-T2V-1.3B # H800-14
config_path=/mnt/nvme0/yongyang/projects/wan/Wan2.1-T2V-1.3B/config.json
python main.py \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--infer_steps 50 \
--target_video_length 81 \
--target_width 832 \
--target_height 480 \
--attention_type flash_attn3 \
--seed 42 \
--sample_neg_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--config_path $config_path \
--save_video_path ./output_lightx2v_seed42.mp4 \
--sample_guide_scale 6 \
--sample_shift 8
# --mm_config '{"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'