Commit d76fc3db authored by Xinchi Huang's avatar Xinchi Huang Committed by Yang Yong(雍洋)

Xinchi/ring attn (#9)



* ring attn & comm

* ring-attn

* ring attn

* ring attn & unit test

* ring attn

* pre-commit reformat

* Update run_wan_i2v.sh

* Update attn.py

---------
Co-authored-by: “huangxinchi” <“huangxinchi@sensetime.com”>
Co-authored-by: “de1star” <“843414674@qq.com”>
Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
Co-authored-by: Yang Yong(雍洋) <yongyang1030@163.com>
parent 181f611a
......@@ -28,7 +28,7 @@ from lightx2v.image2v.models.wan.model import CLIPModel
def load_models(args, model_config):
if model_config["parallel_attn"]:
if model_config["parallel_attn_type"]:
cur_rank = dist.get_rank() # get the rank of the current process
torch.cuda.set_device(cur_rank) # bind this process to its CUDA device
image_encoder = None
......@@ -221,7 +221,7 @@ if __name__ == "__main__":
parser.add_argument("--feature_caching", choices=["NoCaching", "TaylorSeer", "Tea"], default="NoCaching")
parser.add_argument("--mm_config", default=None)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--parallel_attn", action="store_true")
parser.add_argument("--parallel_attn_type", default=None, choices=["ulysses", "ring"])
parser.add_argument("--parallel_vae", action="store_true")
parser.add_argument("--max_area", action="store_true")
parser.add_argument("--vae_stride", default=(4, 8, 8))
......@@ -235,7 +235,7 @@ if __name__ == "__main__":
seed_all(args.seed)
if args.parallel_attn:
if args.parallel_attn_type:
dist.init_process_group(backend="nccl")
if args.mm_config:
......@@ -251,7 +251,7 @@ if __name__ == "__main__":
"do_mm_calib": args.do_mm_calib,
"cpu_offload": args.cpu_offload,
"feature_caching": args.feature_caching,
"parallel_attn": args.parallel_attn,
"parallel_attn_type": args.parallel_attn_type,
"parallel_vae": args.parallel_vae,
}
......@@ -291,7 +291,7 @@ if __name__ == "__main__":
images = run_vae(latents, generator, args)
if not args.parallel_attn or (args.parallel_attn and dist.get_rank() == 0):
if not args.parallel_attn_type or (args.parallel_attn_type and dist.get_rank() == 0):
if args.model_cls == "wan2.1":
cache_video(tensor=images, save_file=args.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
else:
......
from typing import Optional
import torch
import torch.distributed as dist
class RingComm:
def __init__(self, process_group: dist.ProcessGroup = None):
self._process_group = process_group
self._ops = []
self.rank = dist.get_rank(self._process_group)
self.world_size = dist.get_world_size(self._process_group)
self._reqs = None
self.send_rank = (self.rank + 1) % self.world_size
self.recv_rank = (self.rank - 1) % self.world_size
if process_group is not None:
self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
if recv_tensor is None:
res = torch.empty_like(to_send)
# print(f"send_recv: empty_like {to_send.shape}")
else:
res = recv_tensor
send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
self._ops.append(send_op)
self._ops.append(recv_op)
return res
def commit(self):
if self._reqs is not None:
raise RuntimeError("commit called twice")
self._reqs = dist.batch_isend_irecv(self._ops)
def wait(self):
if self._reqs is None:
raise RuntimeError("wait called before commit")
for req in self._reqs:
req.wait()
self._reqs = None
self._ops = []
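# Illustrative usage sketch (assumption, not part of this commit): one ring exchange step
# with RingComm, assuming torch.distributed is already initialized (e.g. via torchrun) and
# local_attention is a placeholder for a per-block attention call such as ring_attn_sub.
def _ring_step_example(comm, q, k, v, local_attention):
    next_k = comm.send_recv(k)      # queue isend(k) to rank+1 and irecv from rank-1
    next_v = comm.send_recv(v)
    comm.commit()                   # launch the batched point-to-point operations
    out = local_attention(q, k, v)  # overlap local compute with communication
    comm.wait()                     # block until next_k / next_v have arrived
    return out, next_k, next_v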
import torch
import torch.distributed as dist
from lightx2v.attentions import attention
import torch.nn.functional as F
# from lightx2v.attentions import attention
from lightx2v.attentions.distributed.comm.ring_comm import RingComm
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward
from typing import Optional, Tuple
# RING_COMM = None
# def init_ring_comm():
# global RING_COMM
# RING_COMM = RingComm()
@torch.jit.script
def _update_out_and_lse(
out: torch.Tensor,
lse: torch.Tensor,
block_out: torch.Tensor,
block_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
block_out = block_out.to(torch.float32)
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
# new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
# torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
# For additional context and discussion, please refer to:
# https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)
return out, lse
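# The two updates above are the numerically stable form of the log-sum-exp merge noted in
# the comments: with new_lse = log(exp(lse) + exp(block_lse)),
#   out <- exp(lse - new_lse) * out + exp(block_lse - new_lse) * block_out
#        = out - sigmoid(block_lse - lse) * (out - block_out)
#   lse <- lse + log(1 + exp(block_lse - lse)) = lse - logsigmoid(lse - block_lse)
# so new_lse never has to be materialized and no large exponentials are computed.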
def update_out_and_lse(
out: Optional[torch.Tensor],
lse: Optional[torch.Tensor],
block_out: torch.Tensor,
block_lse: torch.Tensor,
slice_=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if out is None:
if slice_ is not None:
raise RuntimeError("first update_out_and_lse should not pass slice_ args")
out = block_out.to(torch.float32)
lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
elif slice_ is not None:
slice_out, slice_lse = out[slice_], lse[slice_]
slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
out[slice_], lse[slice_] = slice_out, slice_lse
else:
out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
return out, lse
def ring_attn_sub(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if flash_attn.__version__ < "2.6.3":
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
q,
k,
v,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax,
)
else:
block_out, block_lse, _, _ = _flash_attn_forward(
q,
k,
v,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax,
)
return block_out, block_lse
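# ring_attn_sub wraps the raw flash-attn forward kernel and returns each block's output
# together with its log-sum-exp statistics (block_lse), which update_out_and_lse needs in
# order to merge blocks; the version check above accounts for the _flash_attn_forward
# signature change around flash-attn 2.6.3 (window_size tuple vs. window_size_left/right
# and a shorter return tuple).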
def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
"""
Performs the Ulysses attention mechanism, combining image and text queries, keys, and values.
Performs the Ring attention mechanism, combining image and text queries, keys, and values.
Args:
q (torch.Tensor): query tensor of shape [shard_seqlen, heads, hidden_dims]
......@@ -22,49 +110,79 @@ def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"
cur_rank = dist.get_rank()
world_size = dist.get_world_size()
# Get the sequence length and the text-related lengths
seq_len = q.shape[0]
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # length of the text query/key/value
txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len # length of the text mask
# Get the number of heads and the hidden dimension of the query tensor
_, heads, hidden_dims = q.shape
shard_heads = heads // world_size # number of heads handled by each process
shard_seqlen = img_qkv_len # sequence length handled by each process
# Split the queries, keys and values into image and text parts
img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
gathered_img_k = [torch.empty_like(img_k) for _ in range(world_size)]
gathered_img_v = [torch.empty_like(img_v) for _ in range(world_size)]
dist.all_gather(gathered_img_k, img_k)
dist.all_gather(gathered_img_v, img_v)
torch.cuda.synchronize()
q = q
k = torch.cat(gathered_img_k + [txt_k], dim=0)
v = torch.cat(gathered_img_v + [txt_v], dim=0)
# Initialize the cumulative sequence-length tensor for the queries
cu_seqlens_q = torch.zeros([3], dtype=torch.int32, device="cuda")
s = txt_qkv_len + img_q.shape[0] # total length of text plus image
s1 = s # end position of the current sample
s2 = txt_mask_len + img_q.shape[0] # end position of the text mask
cu_seqlens_q[1] = s1 # set the cumulative sequence length
cu_seqlens_q[2] = s2 # set the cumulative sequence length
max_seqlen_q = img_q.shape[0] + txt_q.shape[0] # maximum query sequence length
# Initialize the cumulative sequence-length tensor for the keys/values
cu_seqlens_kv = torch.zeros([3], dtype=torch.int32, device="cuda")
s = txt_qkv_len + img_k.shape[0] * world_size # total length of text plus image
s1 = s # end position of the current sample
s2 = txt_mask_len + img_k.shape[0] * world_size # end position of the text mask
cu_seqlens_kv[1] = s1 # set the cumulative sequence length
cu_seqlens_kv[2] = s2 # set the cumulative sequence length
max_seqlen_kv = img_k.shape[0] * world_size + txt_q.shape[0] # maximum key/value sequence length
attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv)
return attn
if len(cu_seqlens_qkv) == 3:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # length of the text query/key/value
txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len # length of the text mask
elif len(cu_seqlens_qkv) == 2:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # length of the text query/key/value
txt_mask_len = 0
# if RING_COMM is None:
# init_ring_comm()
RING_COMM = RingComm()
# if len(cu_seqlens_qkv) == 3:
# txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # length of the text query/key/value
# txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len # length of the text mask
# elif len(cu_seqlens_qkv) == 2:
# txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # length of the text query/key/value
# txt_mask_len = None
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous()
txt_q, txt_k, txt_v = (
q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
k[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
)
out, lse, next_k, next_v = None, None, None, None
if len(cu_seqlens_qkv) == 3:
q = torch.cat((img_q, txt_q), dim=1)
k = img_k
v = img_v
for step in range(world_size):
if step + 1 != world_size:
next_k = RING_COMM.send_recv(k)
next_v = RING_COMM.send_recv(v)
RING_COMM.commit()
if step + 1 == world_size:
k = torch.cat((k, txt_k), dim=1)
v = torch.cat((v, txt_v), dim=1)
block_out, block_lse = ring_attn_sub(q, k, v)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
if step + 1 != world_size:
RING_COMM.wait()
k = next_k
v = next_v
attn1 = out.to(torch.bfloat16).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
if txt_mask_len > 0:
attn2, *_ = _flash_attn_forward(
q[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
k[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
v[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
dropout_p=0.0,
softmax_scale=q.shape[-1] ** (-0.5),
causal=False,
window_size_left=-1,
window_size_right=-1,
softcap=0.0,
alibi_slopes=None,
return_softmax=False,
)
attn2 = attn2.to(torch.bfloat16).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
attn1 = torch.cat([attn1, attn2], dim=0)
return attn1
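# Summary of the loop above: each rank's image KV shard circulates around the ring for
# world_size steps, and on the final step the text KV is appended so every rank attends to
# the full image sequence plus the text tokens. Partial results are merged with
# update_out_and_lse, and when txt_mask_len > 0 the extra masked text positions
# (txt_mask_len - txt_qkv_len of them) are handled by a separate flash-attn call whose
# output is concatenated to the main result.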
import torch
import torch.distributed as dist
from lightx2v.attentions import attention
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
from lightx2v.attentions.distributed.ring.attn import ring_attn_sub, update_out_and_lse
from lightx2v.attentions.distributed.comm.ring_comm import RingComm
RING_COMM = None
def init_ring_comm():
global RING_COMM
RING_COMM = RingComm()
def base_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, lq, lk):
attn_out = attention(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=lq,
max_seqlen_kv=lk,
)
return attn_out
def ring_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, lq, lk, ring_size):
out, lse = None, None
# q = torch.chunk(q, ring_size)
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
k = torch.chunk(k, ring_size, dim=1)
v = torch.chunk(v, ring_size, dim=1)
for i in range(ring_size):
k_block, v_block = k[i], v[i]
block_out, block_lse = ring_attn_sub(q, k_block, v_block)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
attn_out = out.to(torch.bfloat16).squeeze(0).reshape(lq, -1)
return attn_out
def ring_attention_dist(q, k, v, cu_seqlens_q, cu_seqlens_k, lq, lk):
if RING_COMM is None:
init_ring_comm()
out, lse = None, None
# q = torch.chunk(q, ring_size)
cur_rank = dist.get_rank()
world_size = dist.get_world_size()
out, lse, next_k, next_v = None, None, None, None
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
k = torch.chunk(k, world_size, dim=1)[cur_rank]
v = torch.chunk(v, world_size, dim=1)[cur_rank]
for step in range(world_size):
if step + 1 != world_size:
next_k = RING_COMM.send_recv(k)
next_v = RING_COMM.send_recv(v)
RING_COMM.commit()
block_out, block_lse = ring_attn_sub(q, k, v)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
if step + 1 != world_size:
RING_COMM.wait()
k = next_k
v = next_v
attn_out = out.to(torch.bfloat16).squeeze(0).reshape(lq, -1)
return attn_out
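# Note: ring_attention above emulates the ring on a single process by chunking K/V locally,
# whereas ring_attention_dist performs the real exchange through RingComm; test() below
# compares the single-process variant against the dense base_attention reference.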
def test():
q = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
k = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
v = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
cu_seqlens_q = torch.tensor([0, 32760], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 32760], dtype=torch.int32, device="cuda")
lq = 32760
lk = 32760
base_attn = base_attention(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, lq=lq, lk=lk)
ring_attn = ring_attention(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, lq=lq, lk=lk, ring_size=4)
# import pdb; pdb.set_trace()
# Assert that the two results are numerically identical
assert torch.allclose(base_attn, ring_attn, rtol=1e-3, atol=1e-3), "base_attn and ring_attn do not match numerically!"
if __name__ == "__main__":
# dist.init_process_group(backend="nccl")
test()
export PYTHONPATH=/home/devsft/huangxinchi/lightx2v:$PYTHONPATH
python3 test.py
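# Note (assumption): test() only exercises the single-process ring_attention path; running
# ring_attention_dist would additionally require launching multiple processes, e.g. with
# torchrun --nproc_per_node=4 test.py, and enabling dist.init_process_group in __main__.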
import functools
from lightx2v.attentions.distributed.ring.attn import ring_attn
from lightx2v.attentions.distributed.utils.process import pre_process, post_process
def parallelize_hunyuan(hunyuan_model):
from lightx2v.attentions.distributed.utils.hunyuan.processor import pre_process, post_process
"""将 Hunyuan 模型的推理过程并行化,使用 Ulysses 注意力机制。
参数:
......@@ -16,34 +17,55 @@ def parallelize_hunyuan(hunyuan_model):
original_infer = hunyuan_model.infer
@functools.wraps(hunyuan_model.__class__.infer) # preserve the metadata of the original infer method
def new_infer(self, latent_model_input, t_expand, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance):
def new_infer(self, text_encoders_output, image_encoder_output, args):
"""新的推理方法,处理输入并调用原始推理方法。
参数:
self: Hunyuan 模型实例
latent_model_input: 潜在模型输入
t_expand: 时间扩展参数
text_states: 文本状态
text_mask: 文本掩码
text_states_2: 第二组文本状态
freqs_cos: 余弦频率
freqs_sin: 正弦频率
guidance: 指导参数
text_encoders_output: 文本编码器的输出
args: 其他参数
返回:
combined_output: 经过后处理的输出结果
None
"""
# Pre-process the input data
latent_model_input, freqs_cos, freqs_sin, split_dim = pre_process(latent_model_input, freqs_cos, freqs_sin)
# Save the original latents and frequency tensors
self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin = (self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin)
# Pre-process the inputs for parallel computation
self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin, split_dim = pre_process(self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin)
# Call the original infer method to get the output
output = original_infer(latent_model_input, t_expand, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance)
original_infer(text_encoders_output, image_encoder_output, args)
# Post-process the output
combined_output = post_process(output, split_dim)
self.scheduler.noise_pred = post_process(self.scheduler.noise_pred, split_dim)
# Restore the original latents and frequency tensors
self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin = (self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin)
return combined_output # return the processed output
# return combined_output # return the processed output (currently commented out)
# Bind the new infer method to the Hunyuan model instance
new_infer = new_infer.__get__(hunyuan_model)
hunyuan_model.infer = new_infer # replace the original infer method
def parallelize_wan(wan_model):
from lightx2v.attentions.distributed.utils.wan.processor import pre_process, post_process
wan_model.transformer_infer.parallel_attention = ring_attn
original_infer = wan_model.transformer_infer.infer
@functools.wraps(wan_model.transformer_infer.__class__.infer) # preserve the metadata of the original infer method
def new_infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
x = pre_process(x)
x = original_infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
x = post_process(x)
return x
new_infer = new_infer.__get__(wan_model.transformer_infer)
wan_model.transformer_infer.infer = new_infer # replace the original infer method
......@@ -8,8 +8,8 @@ from lightx2v.text2v.models.networks.hunyuan.infer.post_infer import HunyuanPost
from lightx2v.text2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.text2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferFeatureCaching
# from lightx2v.core.distributed.partial_heads_attn.wrap import parallelize_hunyuan
from lightx2v.attentions.distributed.ulysses.wrap import parallelize_hunyuan
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
class HunyuanModel:
......@@ -24,8 +24,13 @@ class HunyuanModel:
self._init_weights()
self._init_infer()
if self.config["parallel_attn"]:
parallelize_hunyuan(self)
if config["parallel_attn_type"]:
if config["parallel_attn_type"] == "ulysses":
ulysses_dist_wrap.parallelize_hunyuan(self)
elif config["parallel_attn_type"] == "ring":
ring_dist_wrap.parallelize_hunyuan(self)
else:
raise Exception(f"Unsuppotred parallel_attn_type")
if self.config["cpu_offload"]:
self.to_cpu()
......
......@@ -14,7 +14,8 @@ from lightx2v.text2v.models.networks.wan.infer.transformer_infer import (
)
from lightx2v.text2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferFeatureCaching
from safetensors import safe_open
from lightx2v.attentions.distributed.ulysses.wrap import parallelize_wan
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
class WanModel:
......@@ -29,8 +30,13 @@ class WanModel:
self._init_weights()
self._init_infer()
if config["parallel_attn"]:
parallelize_wan(self)
if config["parallel_attn_type"]:
if config["parallel_attn_type"] == "ulysses":
ulysses_dist_wrap.parallelize_wan(self)
elif config["parallel_attn_type"] == "ring":
ring_dist_wrap.parallelize_wan(self)
else:
raise Exception(f"Unsuppotred parallel_attn_type")
if self.config["cpu_offload"]:
self.to_cpu()
......
#!/bin/bash
lightx2v_path=/mtc/yongyang/projects/lightx2v
lightx2v_path=/home/devsft/huangxinchi/lightx2v
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1,2,3
export CUDA_VISIBLE_DEVICES=4,5,6,7
model_path=/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/t2v
......@@ -18,4 +18,18 @@ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--target_width 1280 \
--attention_type flash_attn2 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
--parallel_attn
--parallel_attn_type ulysses \
--save_video_path ./output_lightx2v_hunyuan_t2v_dist_ulysses.mp4
torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--model_cls hunyuan \
--model_path $model_path \
--prompt "A cat walks on the grass, realistic style." \
--infer_steps 20 \
--target_video_length 33 \
--target_height 720 \
--target_width 1280 \
--attention_type flash_attn2 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
--parallel_attn_type ring \
--save_video_path ./output_lightx2v_hunyuan_t2v_dist_ring.mp4
#!/bin/bash
lightx2v_path=/mtc/yongyang/projects/lightx2v
lightx2v_path=/home/devsft/huangxinchi/lightx2v
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1,2,3
export CUDA_VISIBLE_DEVICES=4,5,6,7
model_path=/mtc/yongyang/models/x2v_models/wan/Wan2.1-T2V-1.3B
config_path=/mtc/yongyang/models/x2v_models/wan/Wan2.1-T2V-1.3B/config.json
......@@ -25,5 +25,26 @@ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--save_video_path ./output_lightx2v_seed42.mp4 \
--sample_guide_scale 6 \
--sample_shift 8 \
--parallel_attn \
--parallel_vae
--parallel_attn_type ring \
--parallel_vae \
--save_video_path ./output_lightx2v_wan_t2v_dist_ring.mp4
torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--infer_steps 50 \
--target_video_length 81 \
--target_width 832 \
--target_height 480 \
--attention_type flash_attn2 \
--seed 42 \
--sample_neg_promp 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--config_path $config_path \
--save_video_path ./output_lightx2v_seed42.mp4 \
--sample_guide_scale 6 \
--sample_shift 8 \
--parallel_attn_type ulysses \
--parallel_vae \
--save_video_path ./output_lightx2v_wan_t2v_dist_ulysses.mp4