Commit aefaf565 authored by Xinchi Huang, committed by Yang Yong(雍洋)

parallel attention (#1)



* parallel attention

* Update run_hunyuan_t2v.sh

* Update main.py

---------
Co-authored-by: “de1star” <“843414674@qq.com”>
Co-authored-by: Yang Yong <yongyang1030@163.com>
parent daf4c74e
*.pth
*.pt
*.onnx
*.pk
*.model
*.zip
*.tar
*.pyc
*.log
*.o
*.so
*.a
*.exe
*.out
.idea
**.DS_Store**
**/__pycache__/**
**.swp
.vscode/
.env
.log
*.pid
*.ipynb*
*.mp4
@@ -26,8 +26,12 @@ def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_att
    # Get the sequence length and the text-related lengths
    seq_len = q.shape[0]
    if len(cu_seqlens_qkv) == 3:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query, key and value
        txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
    elif len(cu_seqlens_qkv) == 2:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query, key and value
        txt_mask_len = None
    # Get the number of heads and the hidden dimension of the query tensor
    _, heads, hidden_dims = q.shape
@@ -58,8 +62,9 @@ def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_att
    cu_seqlens_qkv = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_q.shape[0]  # total length of text and image
    s1 = s  # end position of the current sample
    cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
    if txt_mask_len:
        s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
        cu_seqlens_qkv[2] = s2  # set the cumulative sequence length
    max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length
...
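For orientation, the cumulative-sequence-length bookkeeping above can be pictured with concrete numbers. The sketch below is illustrative only; the token counts are hypothetical and the tensor is built on CPU rather than CUDA.

# Illustrative only: what cu_seqlens_qkv ends up holding for one image+text sample.
# The lengths are made-up numbers, not values from this commit.
import torch

img_qkv_len = 32000   # image tokens on this rank
txt_qkv_len = 256     # real text tokens appended after the image tokens
txt_mask_len = 320    # padded text length (the optional third entry)

cu_seqlens_qkv = torch.zeros([3], dtype=torch.int32)
cu_seqlens_qkv[1] = txt_qkv_len + img_qkv_len       # s1: end of the real image+text tokens
if txt_mask_len:
    cu_seqlens_qkv[2] = txt_mask_len + img_qkv_len  # s2: end of the padded/mask region

print(cu_seqlens_qkv)  # tensor([    0, 32256, 32320], dtype=torch.int32)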
import functools

from lightx2v.attentions.distributed.ulysses.attn import ulysses_attn


def parallelize_hunyuan(hunyuan_model):
    from lightx2v.attentions.distributed.utils.hunyuan.processor import pre_process, post_process
    """Parallelize the inference of the Hunyuan model using the Ulysses attention mechanism.

    Args:
@@ -15,7 +15,7 @@ def parallelize_hunyuan(hunyuan_model):
    original_infer = hunyuan_model.infer

    @functools.wraps(hunyuan_model.__class__.infer)  # preserve the metadata of the original inference method
    def new_infer(self, text_encoders_output, image_encoder_output, args):
        """New inference method: pre-process the inputs and call the original inference method.

        Args:
@@ -39,8 +39,8 @@ def parallelize_hunyuan(hunyuan_model):
        )

        # Call the original inference method to get the output
        original_infer(
            text_encoders_output, image_encoder_output, args
        )

        # Post-process the output
@@ -58,3 +58,28 @@ def parallelize_hunyuan(hunyuan_model):
    # Bind the new inference method to the Hunyuan model instance
    new_infer = new_infer.__get__(hunyuan_model)
    hunyuan_model.infer = new_infer  # replace the original inference method
def parallelize_wan(wan_model):
    from lightx2v.attentions.distributed.utils.wan.processor import pre_process, post_process

    wan_model.transformer_infer.parallel_attention = ulysses_attn
    original_infer = wan_model.transformer_infer.infer

    @functools.wraps(wan_model.transformer_infer.__class__.infer)  # preserve the metadata of the original inference method
    def new_infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        x = pre_process(x)
        x = original_infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
        x = post_process(x)
        return x

    new_infer = new_infer.__get__(wan_model.transformer_infer)
    wan_model.transformer_infer.infer = new_infer  # replace the original inference method
\ No newline at end of file
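The wrapper functions above rely on binding a plain function to a single instance with __get__, so only that object's infer is replaced while the class stays untouched. A minimal, self-contained sketch of that pattern follows; the Dummy class and its pre/post steps are hypothetical stand-ins, not lightx2v code.

# Minimal sketch of the wrapping pattern used above, under assumed names.
import functools

class Dummy:
    def infer(self, x):
        return x * 2

model = Dummy()
original_infer = model.infer  # bound method captured before patching

@functools.wraps(Dummy.infer)
def new_infer(self, x):
    x = x + 1                   # stand-in for pre_process
    out = original_infer(x)     # call the original bound method
    return out - 1              # stand-in for post_process

model.infer = new_infer.__get__(model)  # bind to the instance, leaving the class untouched
print(model.infer(3))  # 7  -> (3 + 1) * 2 - 1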
import torch
import torch.distributed as dist


def pre_process(latent_model_input, freqs_cos, freqs_sin):
    '''
    Pre-process the latent model input and the frequency tensors, splitting them for distributed computation.

    Args:
        latent_model_input (torch.Tensor): latent model input of shape [batch_size, channels, temporal_size, height, width]
        freqs_cos (torch.Tensor): cosine frequencies of shape [batch_size, channels, temporal_size, height, width]
        freqs_sin (torch.Tensor): sine frequencies of shape [batch_size, channels, temporal_size, height, width]

    Returns:
        tuple: the processed latent_model_input, freqs_cos, freqs_sin and the split dimension split_dim
    '''
    # Get the world size and the rank of the current process
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()

    # Determine the split dimension from the input shape
    if latent_model_input.shape[-2] // 2 % world_size == 0:
        split_dim = -2  # split along height
    elif latent_model_input.shape[-1] // 2 % world_size == 0:
        split_dim = -1  # split along width
    else:
        raise ValueError(f"Cannot split video sequence into world size ({world_size}) parts evenly")

    # Get the temporal size and the halved height and width
    temporal_size, h, w = latent_model_input.shape[2], latent_model_input.shape[3] // 2, latent_model_input.shape[4] // 2

    # Split the latent model input along the chosen dimension
    latent_model_input = torch.chunk(latent_model_input, world_size, dim=split_dim)[cur_rank]

    # Process the cosine frequencies
    dim_thw = freqs_cos.shape[-1]  # last dimension of the frequency tensor
    freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)  # reshape to [temporal_size, height, width, dim_thw]
    freqs_cos = torch.chunk(freqs_cos, world_size, dim=split_dim - 1)[cur_rank]  # split the frequencies
    freqs_cos = freqs_cos.reshape(-1, dim_thw)  # flatten back to [local_seq_len, dim_thw]

    # Process the sine frequencies
    dim_thw = freqs_sin.shape[-1]  # last dimension of the frequency tensor
    freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)  # reshape to [temporal_size, height, width, dim_thw]
    freqs_sin = torch.chunk(freqs_sin, world_size, dim=split_dim - 1)[cur_rank]  # split the frequencies
    freqs_sin = freqs_sin.reshape(-1, dim_thw)  # flatten back to [local_seq_len, dim_thw]

    return latent_model_input, freqs_cos, freqs_sin, split_dim  # return the processed tensors


def post_process(output, split_dim):
    """Post-process the output: gather the outputs of all processes and concatenate them.

    Args:
        output (torch.Tensor): output of the current process, shape [batch_size, ...]
        split_dim (int): split dimension along which the outputs are concatenated

    Returns:
        torch.Tensor: the combined output, shape [world_size * batch_size, ...]
    """
    # Get the world size
    world_size = dist.get_world_size()
    # Allocate a list to hold the outputs of all processes
    gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
    # Gather the outputs from all processes
    dist.all_gather(gathered_outputs, output)
    # Concatenate the gathered outputs along the split dimension
    combined_output = torch.cat(gathered_outputs, dim=split_dim)
    return combined_output  # return the combined output
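The pre/post pair above is a chunk-then-all_gather round trip. The standalone sketch below exercises that round trip with two processes on the gloo backend; the addresses, tensor sizes, and helper names are assumptions for the example, not part of this commit.

# Illustrative only: split along one dimension, decode/compute locally, gather back.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank, world_size):
    # Single-node bootstrap for the demo; address and port are placeholders.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    split_dim = -1                                              # pretend we split along width
    full = torch.arange(4 * 8, dtype=torch.float32).reshape(4, 8)
    local = torch.chunk(full, world_size, dim=split_dim)[rank]  # what pre_process hands this rank

    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)                            # what post_process does
    restored = torch.cat(gathered, dim=split_dim)

    assert torch.equal(restored, full)                          # the round trip is lossless
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)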
from re import split

import torch
import torch.distributed as dist


def pre_process(x):
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()
    x = torch.chunk(x, world_size, dim=0)[cur_rank]
    return x


def post_process(x):
    # Get the world size
    world_size = dist.get_world_size()
    # Allocate a list to hold the outputs of all processes
    gathered_x = [torch.empty_like(x) for _ in range(world_size)]
    # Gather the outputs from all processes
    dist.all_gather(gathered_x, x)
    # Concatenate the gathered outputs along dim 0
    combined_output = torch.cat(gathered_x, dim=0)
    return combined_output  # return the combined output
\ No newline at end of file
import torch
from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb, rms_norm
from lightx2v.attentions import attention
@@ -12,6 +12,7 @@ class WanTransformerInfer:
        self.num_heads = config["num_heads"]
        self.head_dim = config["dim"] // config["num_heads"]
        self.window_size = config.get("window_size", (-1, -1))
        self.parallel_attention = None

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
@@ -45,6 +46,7 @@ class WanTransformerInfer:
            freqs,
            context,
        )

        return x

    def infer_block(
@@ -69,13 +71,17 @@ class WanTransformerInfer:
        v = weights.self_attn_v.apply(norm1_out).view(s, n, d)

        if not self.parallel_attention:
            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
        else:
            freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)

        q = apply_rotary_emb(q, freqs_i)
        k = apply_rotary_emb(k, freqs_i)

        cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=seq_lens)

        if not self.parallel_attention:
            attn_out = attention(
                attention_type=self.attention_type,
                q=q,
@@ -86,6 +92,17 @@ class WanTransformerInfer:
                max_seqlen_q=lq,
                max_seqlen_kv=lk,
            )
        else:
            attn_out = self.parallel_attention(
                attention_type=self.attention_type,
                q=q,
                k=k,
                v=v,
                img_qkv_len=q.shape[0],
                cu_seqlens_qkv=cu_seqlens_q
                # cu_seqlens_qkv=cu_seqlens_qkv,
                # max_seqlen_qkv=max_seqlen_qkv,
            )

        y = weights.self_attn_o.apply(attn_out)
        x = x + y * embed0[2].squeeze(0)
...
import torch
import sgl_kernel
import torch.cuda.amp as amp
import torch.distributed as dist


def rms_norm(x, weight, eps):
@@ -27,6 +28,41 @@ def compute_freqs(c, grid_sizes, freqs):
    return freqs_i


def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


def compute_freqs_dist(s, c, grid_sizes, freqs):
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0].tolist()
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)
    freqs_i = pad_freqs(freqs_i, s * world_size)
    s_per_rank = s
    freqs_i_rank = freqs_i[(cur_rank * s_per_rank):((cur_rank + 1) * s_per_rank), :, :]
    return freqs_i_rank


def apply_rotary_emb(x, freqs_i):
    n = x.size(1)
    seq_len = freqs_i.size(0)
...
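As a quick orientation for compute_freqs_dist: pad_freqs rounds the rotary table up to s * world_size rows, and each rank then keeps the contiguous block of positions matching the rows it received from the dim-0 chunk of the sequence. The grid size and world size below are assumptions; the snippet only prints slice boundaries and does not touch torch.distributed.

# Illustrative only: per-rank RoPE slice boundaries under assumed sizes.
f, h, w = 21, 30, 52                # latent frames, height, width (hypothetical)
seq_len = f * h * w                 # 32760 rotary positions for the full sequence
world_size = 4
s = -(-seq_len // world_size)       # per-rank length after torch.chunk along dim 0 (8190 here)

# pad_freqs pads the table with ones up to s * world_size rows so every slice is full-sized.
padded_len = s * world_size
for rank in range(world_size):
    start, stop = rank * s, (rank + 1) * s
    print(f"rank {rank}: rows [{start}, {stop}) of {padded_len} padded positions")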
@@ -14,6 +14,7 @@ from lightx2v.text2v.models.networks.wan.infer.transformer_infer import (
)
from lightx2v.text2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferFeatureCaching
from safetensors import safe_open
from lightx2v.attentions.distributed.ulysses.wrap import parallelize_wan


class WanModel:
@@ -28,6 +29,9 @@ class WanModel:
        self._init_weights()
        self._init_infer()

        if config['parallel_attn']:
            parallelize_wan(self)

    def _init_infer_class(self):
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
...
@@ -5,6 +5,7 @@ import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from einops import rearrange

__all__ = [
@@ -706,9 +707,11 @@ class WanVAE:
        vae_pth="cache/vae_step_411000.pth",
        dtype=torch.float,
        device="cuda",
        parallel=False,
    ):
        self.dtype = dtype
        self.device = device
        self.parallel = parallel

        mean = [
            -0.7571,
@@ -770,5 +773,69 @@ class WanVAE:
            for u in videos
        ]

    def decode_dist(self, zs, world_size, cur_rank, split_dim):
        splited_total_len = zs.shape[split_dim]
        splited_chunk_len = splited_total_len // world_size
        padding_size = 1
        if cur_rank == 0:
            if split_dim == 2:
                zs = zs[:, :, :splited_chunk_len + 2 * padding_size, :].contiguous()
            elif split_dim == 3:
                zs = zs[:, :, :, :splited_chunk_len + 2 * padding_size].contiguous()
        elif cur_rank == world_size - 1:
            if split_dim == 2:
                zs = zs[:, :, -(splited_chunk_len + 2 * padding_size):, :].contiguous()
            elif split_dim == 3:
                zs = zs[:, :, :, -(splited_chunk_len + 2 * padding_size):].contiguous()
        else:
            if split_dim == 2:
                zs = zs[:, :, cur_rank * splited_chunk_len - padding_size:(cur_rank + 1) * splited_chunk_len + padding_size, :].contiguous()
            elif split_dim == 3:
                zs = zs[:, :, :, cur_rank * splited_chunk_len - padding_size:(cur_rank + 1) * splited_chunk_len + padding_size].contiguous()

        images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)

        if cur_rank == 0:
            if split_dim == 2:
                images = images[:, :, :, :splited_chunk_len * 8, :].contiguous()
            elif split_dim == 3:
                images = images[:, :, :, :, :splited_chunk_len * 8].contiguous()
        elif cur_rank == world_size - 1:
            if split_dim == 2:
                images = images[:, :, :, -splited_chunk_len * 8:, :].contiguous()
            elif split_dim == 3:
                images = images[:, :, :, :, -splited_chunk_len * 8:].contiguous()
        else:
            if split_dim == 2:
                images = images[:, :, :, 8 * padding_size:-8 * padding_size, :].contiguous()
            elif split_dim == 3:
                images = images[:, :, :, :, 8 * padding_size:-8 * padding_size].contiguous()

        full_images = [torch.empty_like(images) for _ in range(world_size)]
        dist.all_gather(full_images, images)
        torch.cuda.synchronize()
        images = torch.cat(full_images, dim=-1)
        return images

    def decode(self, zs, generator, args):
        if self.parallel:
            world_size = dist.get_world_size()
            cur_rank = dist.get_rank()
            height, width = zs.shape[2], zs.shape[3]
            if width % world_size == 0:
                split_dim = 3
                images = self.decode_dist(zs, world_size, cur_rank, split_dim)
            elif height % world_size == 0:
                split_dim = 2
                images = self.decode_dist(zs, world_size, cur_rank, split_dim)
            else:
                print("Fall back to naive decode mode")
                images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
        else:
            images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
        return images
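decode_dist splits the latent along height or width with a one-latent-column halo on each interior boundary, decodes each slice independently, trims the halo from the 8x-upsampled pixels, and all_gathers the pieces. The arithmetic can be sanity-checked without GPUs; the latent width and world size below are assumptions for illustration only.

# Illustrative only: latent slice and pixel trim per rank for a width split (split_dim == 3).
latent_w = 104                        # hypothetical zs.shape[3]
world_size = 4
chunk = latent_w // world_size        # 26 latent columns per rank
pad = 1                               # halo of one latent column per interior boundary

for rank in range(world_size):
    if rank == 0:
        lat = (0, chunk + 2 * pad)                        # extra columns on the right only
        pix = (0, chunk * 8)                              # keep the first chunk*8 pixel columns
    elif rank == world_size - 1:
        lat = (latent_w - chunk - 2 * pad, latent_w)
        pix = (-chunk * 8, None)                          # keep the last chunk*8 pixel columns
    else:
        lat = (rank * chunk - pad, (rank + 1) * chunk + pad)
        pix = (8 * pad, -8 * pad)                         # trim 8*pad pixel columns on both sides
    print(f"rank {rank}: latent cols {lat} -> pixel cols kept {pix}")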
@@ -55,7 +55,7 @@ def load_models(args, model_config):
    )
    text_encoders = [text_encoder]
    model = WanModel(args.model_path, model_config)
    vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=torch.device("cuda"), parallel=args.parallel_vae)
    if args.task == 'i2v':
        image_encoder = CLIPModel(
            dtype=torch.float16,
@@ -241,6 +241,7 @@ if __name__ == "__main__":
    parser.add_argument('--mm_config', default=None)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--parallel_attn', action='store_true')
    parser.add_argument('--parallel_vae', action='store_true')
    parser.add_argument('--max_area', action='store_true')
    parser.add_argument('--vae_stride', default=(4, 8, 8))
    parser.add_argument('--patch_size', default=(1, 2, 2))
@@ -269,7 +270,8 @@ if __name__ == "__main__":
        "do_mm_calib": args.do_mm_calib,
        "cpu_offload": args.cpu_offload,
        "feature_caching": args.feature_caching,
        "parallel_attn": args.parallel_attn,
        "parallel_vae": args.parallel_vae
    }

    if args.config_path is not None:
@@ -308,6 +310,7 @@ if __name__ == "__main__":
    images = run_vae(latents, generator, args)

    if not args.parallel_attn or (args.parallel_attn and dist.get_rank() == 0):
        if args.model_cls == "wan2.1":
            cache_video(tensor=images, save_file=args.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
        else:
...
#!/bin/bash
# model_path=/mnt/nvme1/yongyang/models/hy/ckpts # H800-13
# model_path=/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts # H800-14
model_path=/workspace/ckpts_link # H800-14
# export CUDA_VISIBLE_DEVICES=2
# python main.py \
# --model_cls hunyuan \
# --model_path $model_path \
# --prompt "A cat walks on the grass, realistic style." \
# --infer_steps 20 \
# --target_video_length 33 \
# --target_height 720 \
# --target_width 1280 \
# --attention_type flash_attn3 \
# --save_video_path ./output_lightx2v_int8.mp4 \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
# export CUDA_VISIBLE_DEVICES=0,1,2,3
# torchrun --nproc_per_node=4 main.py \
# --model_cls hunyuan \
# --model_path $model_path \
# --prompt "A cat walks on the grass, realistic style." \
# --infer_steps 20 \
# --target_video_length 33 \
# --target_height 720 \
# --target_width 1280 \
# --attention_type flash_attn2 \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
# --parallel_attn
export CUDA_VISIBLE_DEVICES=2
python main.py \
--model_cls hunyuan \
--model_path $model_path \
--prompt "A cat walks on the grass, realistic style." \
--infer_steps 20 \
--target_video_length 33 \
--target_height 720 \
--target_width 1280 \
--attention_type flash_attn3 \
--save_video_path ./output_lightx2v_int8.mp4 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
# model_path=/mnt/nvme1/yongyang/models/hy/ckpts # H800-13
# model_path=/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts # H800-14
model_path=/workspace/ckpts_link # H800-14
export CUDA_VISIBLE_DEVICES=0,1,2,3
torchrun --nproc_per_node=4 main.py \
--model_cls hunyuan \
--model_path $model_path \
--prompt "A cat walks on the grass, realistic style." \
--infer_steps 20 \
--target_video_length 33 \
--target_height 720 \
--target_width 1280 \
--attention_type flash_attn2 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
--parallel_attn
\ No newline at end of file
# model_path=/mnt/nvme1/yongyang/models/hy/ckpts # H800-13
# model_path=/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts # H800-14
model_path=/workspace/ckpts_link # H800-14
export CUDA_VISIBLE_DEVICES=2
python main.py \
--model_cls hunyuan \
--model_path $model_path \
--prompt "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting." \
--infer_steps 50 \
--target_video_length 65 \
--target_height 480 \
--target_width 640 \
--attention_type flash_attn2 \
--cpu_offload \
--feature_caching TaylorSeer \
--save_video_path ./output_lightx2v_offload_TaylorSeer.mp4 \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
\ No newline at end of file
#!/bin/bash
export CUDA_VISIBLE_DEVICES=2
# model_path=/mnt/nvme1/yongyang/models/hy/ckpts # H800-13
# model_path=/workspace/wan/Wan2.1-T2V-1.3B # H800-14
# config_path=/workspace/wan/Wan2.1-T2V-1.3B/config.json
model_path=/mnt/nvme0/yongyang/projects/wan/Wan2.1-T2V-1.3B # H800-14
config_path=/mnt/nvme0/yongyang/projects/wan/Wan2.1-T2V-1.3B/config.json
python main.py \
--model_cls wan2.1 \
--task t2v \
@@ -15,7 +16,7 @@ python main.py \
--target_video_length 81 \
--target_width 832 \
--target_height 480 \
--attention_type flash_attn2 \
--seed 42 \
--sample_neg_promp 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--config_path $config_path \
...
model_path=/workspace/wan/Wan2.1-T2V-1.3B # H800-14
config_path=/workspace/wan/Wan2.1-T2V-1.3B/config.json
export CUDA_VISIBLE_DEVICES=4,5,6,7
torchrun --nproc_per_node=4 main.py \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--infer_steps 50 \
--target_video_length 84 \
--target_width 832 \
--target_height 480 \
--attention_type flash_attn2 \
--seed 42 \
--sample_neg_promp 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--config_path $config_path \
--save_video_path ./output_lightx2v_seed42.mp4 \
--sample_guide_scale 6 \
--sample_shift 8 \
--parallel_attn \
--parallel_vae