Commit 1a881d63 authored by helloyongyang

Refactor the parallel module

parent 18e2b23a
import functools
from lightx2v.attentions.distributed.ulysses.attn import ulysses_attn
def parallelize_hunyuan(hunyuan_model):
"""Parallelize the Hunyuan model's inference with the Ulysses attention mechanism.
Args:
hunyuan_model: Hunyuan model instance, containing the inference method and other attributes.
"""
from lightx2v.attentions.distributed.utils.hunyuan.processor import pre_process, post_process
# Replace the Hunyuan model's parallel attention with Ulysses attention
hunyuan_model.transformer_infer.parallel_attention = ulysses_attn
# Keep a reference to the original inference method so it can still be called later
original_infer = hunyuan_model.infer
@functools.wraps(hunyuan_model.__class__.infer)  # preserve the metadata of the original inference method
def new_infer(self, text_encoders_output, image_encoder_output, args):
"""New inference method: pre-process the inputs, call the original inference method, then post-process.
Args:
self: Hunyuan model instance
text_encoders_output: output of the text encoders
image_encoder_output: output of the image encoder
args: additional arguments
Returns:
None
"""
# Save the original latent model input and rotary frequency tensors
self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin = (self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin)
# Pre-process the inputs so they can be split across ranks for parallel computation
self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin, split_dim = pre_process(self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin)
# Call the original inference method
original_infer(text_encoders_output, image_encoder_output, args)
# Post-process the output (gather the shards from all ranks)
self.scheduler.noise_pred = post_process(self.scheduler.noise_pred, split_dim)
# Restore the original latent model input and rotary frequency tensors
self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin = (self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin)
# return combined_output  # returning the combined output is currently disabled
# Bind the new inference method to the Hunyuan model instance
new_infer = new_infer.__get__(hunyuan_model)
hunyuan_model.infer = new_infer  # replace the original inference method
def parallelize_wan(wan_model):
from lightx2v.attentions.distributed.utils.wan.processor import pre_process, post_process
wan_model.transformer_infer.parallel_attention = ulysses_attn
original_infer = wan_model.transformer_infer.infer
@functools.wraps(wan_model.transformer_infer.__class__.infer)  # preserve the metadata of the original inference method
def new_infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
x = pre_process(x)
x = original_infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
x = post_process(x)
return x
new_infer = new_infer.__get__(wan_model.transformer_infer)
wan_model.transformer_infer.infer = new_infer  # replace the original inference method
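# Hypothetical usage sketch (not part of this commit; the construction call and the
# import path are assumptions) showing how these wrappers are meant to be applied once
# torch.distributed has been initialized:
#
#   import torch.distributed as dist
#   dist.init_process_group(backend="nccl")
#   hunyuan_model = ...  # build the model as usual
#   parallelize_hunyuan(hunyuan_model)  # infer() now shards latents/freqs across ranks
#   hunyuan_model.infer(text_encoders_output, image_encoder_output, args)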
import torch
import torch.distributed as dist
def pre_process(latent_model_input, freqs_cos, freqs_sin):
"""
对输入的潜在模型数据和频率数据进行预处理,进行切分以适应分布式计算。
参数:
latent_model_input (torch.Tensor): 输入的潜在模型数据,形状为 [batch_size, channels, temporal_size, height, width]
freqs_cos (torch.Tensor): 余弦频率数据,形状为 [batch_size, channels, temporal_size, height, width]
freqs_sin (torch.Tensor): 正弦频率数据,形状为 [batch_size, channels, temporal_size, height, width]
返回:
tuple: 处理后的 latent_model_input, freqs_cos, freqs_sin 和切分维度 split_dim
"""
# 获取当前进程的世界大小和当前进程的排名
world_size = dist.get_world_size()
cur_rank = dist.get_rank()
# 根据输入的形状确定切分维度
if latent_model_input.shape[-2] // 2 % world_size == 0:
split_dim = -2 # 按高度切分
elif latent_model_input.shape[-1] // 2 % world_size == 0:
split_dim = -1 # 按宽度切分
else:
raise ValueError(f"Cannot split video sequence into world size ({world_size}) parts evenly")
# 获取时间维度、处理后的高度和宽度
temporal_size, h, w = latent_model_input.shape[2], latent_model_input.shape[3] // 2, latent_model_input.shape[4] // 2
# 按照确定的维度切分潜在模型输入
latent_model_input = torch.chunk(latent_model_input, world_size, dim=split_dim)[cur_rank]
# 处理余弦频率数据
dim_thw = freqs_cos.shape[-1] # 获取频率数据的最后一个维度
freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw) # 重塑为 [temporal_size, height, width, dim_thw]
freqs_cos = torch.chunk(freqs_cos, world_size, dim=split_dim - 1)[cur_rank] # 切分频率数据
freqs_cos = freqs_cos.reshape(-1, dim_thw) # 重塑为 [batch_size, dim_thw]
# 处理正弦频率数据
dim_thw = freqs_sin.shape[-1] # 获取频率数据的最后一个维度
freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw) # 重塑为 [temporal_size, height, width, dim_thw]
freqs_sin = torch.chunk(freqs_sin, world_size, dim=split_dim - 1)[cur_rank] # 切分频率数据
freqs_sin = freqs_sin.reshape(-1, dim_thw) # 重塑为 [batch_size, dim_thw]
return latent_model_input, freqs_cos, freqs_sin, split_dim # 返回处理后的数据
def post_process(output, split_dim):
"""对输出进行后处理,收集所有进程的输出并合并。
参数:
output (torch.Tensor): 当前进程的输出,形状为 [batch_size, ...]
split_dim (int): 切分维度,用于合并输出
返回:
torch.Tensor: 合并后的输出,形状为 [world_size * batch_size, ...]
"""
# 获取当前进程的世界大小
world_size = dist.get_world_size()
# 创建一个列表,用于存储所有进程的输出
gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
# 收集所有进程的输出
dist.all_gather(gathered_outputs, output)
# 在指定的维度上合并所有进程的输出
combined_output = torch.cat(gathered_outputs, dim=split_dim)
return combined_output # 返回合并后的输出
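# A minimal single-process sketch (example numbers only, no torch.distributed needed)
# of how pre_process above picks the split dimension: the post-patchify height or
# width (shape // 2) must divide evenly by the number of ranks.
def _pick_split_dim_example():
    world_size = 4
    latent_shape = (1, 16, 33, 60, 104)  # hypothetical [batch, channels, temporal, height, width]
    if latent_shape[-2] // 2 % world_size == 0:
        split_dim = -2  # split along height
    elif latent_shape[-1] // 2 % world_size == 0:
        split_dim = -1  # split along width
    else:
        raise ValueError(f"Cannot split video sequence into world size ({world_size}) parts evenly")
    return split_dim  # -1 here: 104 // 2 == 52 is divisible by 4, while 60 // 2 == 30 is not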
import torch
import torch.distributed as dist
def pre_process(latent_model_input, freqs_cos, freqs_sin):
"""
对输入的潜在模型数据和频率数据进行预处理,进行切分以适应分布式计算。
参数:
latent_model_input (torch.Tensor): 输入的潜在模型数据,形状为 [batch_size, channels, temporal_size, height, width]
freqs_cos (torch.Tensor): 余弦频率数据,形状为 [batch_size, channels, temporal_size, height, width]
freqs_sin (torch.Tensor): 正弦频率数据,形状为 [batch_size, channels, temporal_size, height, width]
返回:
tuple: 处理后的 latent_model_input, freqs_cos, freqs_sin 和切分维度 split_dim
"""
# 获取当前进程的世界大小和当前进程的排名
world_size = dist.get_world_size()
cur_rank = dist.get_rank()
# 根据输入的形状确定切分维度
if latent_model_input.shape[-2] // 2 % world_size == 0:
split_dim = -2 # 按高度切分
elif latent_model_input.shape[-1] // 2 % world_size == 0:
split_dim = -1 # 按宽度切分
else:
raise ValueError(f"Cannot split video sequence into world size ({world_size}) parts evenly")
# 获取时间维度、处理后的高度和宽度
temporal_size, h, w = latent_model_input.shape[2], latent_model_input.shape[3] // 2, latent_model_input.shape[4] // 2
# 按照确定的维度切分潜在模型输入
latent_model_input = torch.chunk(latent_model_input, world_size, dim=split_dim)[cur_rank]
# 处理余弦频率数据
dim_thw = freqs_cos.shape[-1] # 获取频率数据的最后一个维度
freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw) # 重塑为 [temporal_size, height, width, dim_thw]
freqs_cos = torch.chunk(freqs_cos, world_size, dim=split_dim - 1)[cur_rank] # 切分频率数据
freqs_cos = freqs_cos.reshape(-1, dim_thw) # 重塑为 [batch_size, dim_thw]
# 处理正弦频率数据
dim_thw = freqs_sin.shape[-1] # 获取频率数据的最后一个维度
freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw) # 重塑为 [temporal_size, height, width, dim_thw]
freqs_sin = torch.chunk(freqs_sin, world_size, dim=split_dim - 1)[cur_rank] # 切分频率数据
freqs_sin = freqs_sin.reshape(-1, dim_thw) # 重塑为 [batch_size, dim_thw]
return latent_model_input, freqs_cos, freqs_sin, split_dim # 返回处理后的数据
def post_process(output, split_dim):
"""对输出进行后处理,收集所有进程的输出并合并。
参数:
output (torch.Tensor): 当前进程的输出,形状为 [batch_size, ...]
split_dim (int): 切分维度,用于合并输出
返回:
torch.Tensor: 合并后的输出,形状为 [world_size * batch_size, ...]
"""
# 获取当前进程的世界大小
world_size = dist.get_world_size()
# 创建一个列表,用于存储所有进程的输出
gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
# 收集所有进程的输出
dist.all_gather(gathered_outputs, output)
# 在指定的维度上合并所有进程的输出
combined_output = torch.cat(gathered_outputs, dim=split_dim)
return combined_output # 返回合并后的输出
import torch
import torch.distributed as dist
import torch.nn.functional as F
PADDING_SIZE = None
def pre_process(x):
world_size = dist.get_world_size()
cur_rank = dist.get_rank()
padding_size = (world_size - (x.shape[0] % world_size)) % world_size
if padding_size > 0:
# Pad the sequence (first) dimension with F.pad; the pad tuple is ordered from the last dimension backwards
x = F.pad(x, (0, 0, 0, padding_size))
x = torch.chunk(x, world_size, dim=0)[cur_rank]
return x
def post_process(x):
# Get the world size
world_size = dist.get_world_size()
# Buffers for the outputs of all ranks
gathered_x = [torch.empty_like(x) for _ in range(world_size)]
# Gather the outputs from all ranks
dist.all_gather(gathered_x, x)
# Concatenate the gathered outputs along the sequence dimension
combined_output = torch.cat(gathered_x, dim=0)
return combined_output  # return the combined output
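# A small standalone sketch (illustrative sequence lengths) of the padding rule used in
# pre_process above: the token count is rounded up to a multiple of world_size so that
# torch.chunk yields equally sized shards on every rank.
def _padding_example():
    world_size = 4
    for seq_len in (75, 76, 80):
        padding_size = (world_size - (seq_len % world_size)) % world_size
        print(seq_len, "->", seq_len + padding_size)  # 75 -> 76, 76 -> 76, 80 -> 80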
from .attn_weight import *
from .flash_attn import *
from .radial_attn import *
from .ring_attn import *
from .sage_attn import *
from .torch_sdpa import *
from .ulysses_attn import *
from .sparge_attn import *
import torch
import torch.nn as nn
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
import torch.nn.functional as F
from loguru import logger
try:
from spas_sage_attn.autotune import SparseAttentionMeansim
except ImportError:
logger.info("SparseAttentionMeansim not found, please install sparge first")
SparseAttentionMeansim = None
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
flash_attn_varlen_func = None
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
flash_attn_varlen_func_v3 = None
if torch.cuda.get_device_capability(0) == (8, 9):
try:
from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
except ImportError:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
else:
try:
from sageattention import sageattn
except ImportError:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
from lightx2v.attentions.common.radial_attn import radial_attn
class AttnWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name):
self.weight_name = weight_name
self.config = {}
def load(self, weight_dict):
pass
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
def to_cpu(self, non_blocking=False):
pass
def to_cuda(self, non_blocking=False):
pass
def state_dict(self, destination=None):
if destination is None:
destination = {}
return destination
@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
mask_map=None,
):
x = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("flash_attn3")
class FlashAttn3Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
mask_map=None,
):
x = flash_attn_varlen_func_v3(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
)[0].reshape(max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("radial_attn")
class RadialAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
mask_map=None,
sparsity_type="radial",
block_size=128,
decay_factor=1,
model_cls="wan",
):
assert len(q.shape) == 3
x = radial_attn(
q,
k,
v,
mask_map=mask_map,
sparsity_type=sparsity_type,
block_size=block_size,
model_cls=model_cls[:3], # Use first 3 characters to match "wan", "wan2", etc.
decay_factor=decay_factor,
)
x = x.view(max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
mask_map=None,
):
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
if model_cls == "hunyuan":
x1 = sageattn(
q[: cu_seqlens_q[1]].unsqueeze(0),
k[: cu_seqlens_kv[1]].unsqueeze(0),
v[: cu_seqlens_kv[1]].unsqueeze(0),
tensor_layout="NHD",
)
x2 = sageattn(
q[cu_seqlens_q[1] :].unsqueeze(0),
k[cu_seqlens_kv[1] :].unsqueeze(0),
v[cu_seqlens_kv[1] :].unsqueeze(0),
tensor_layout="NHD",
)
x = torch.cat((x1, x2), dim=1)
x = x.view(max_seqlen_q, -1)
elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df"]:
x = sageattn(
q.unsqueeze(0),
k.unsqueeze(0),
v.unsqueeze(0),
tensor_layout="NHD",
)
x = x.view(max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("torch_sdpa")
class TorchSDPAWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
drop_rate=0,
attn_mask=None,
causal=False,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
mask_map=None,
):
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
x = x.transpose(1, 2)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out.squeeze(0)
@ATTN_WEIGHT_REGISTER("Sparge")
class SpargeAttnWeight(AttnWeightTemplate):
def __init__(
self,
weight_name,
verbose=False,
l1=0.07,
pv_l1=0.08,
tune_pv=True,
inner_attn_type="flash_attn3",
):
self.verbose = verbose
self.l1 = l1
self.pv_l1 = pv_l1
self.tune_pv = tune_pv
self.inner_attn_type = inner_attn_type
self.inner_cls = SparseAttentionMeansim(l1=l1, pv_l1=pv_l1, tune_pv=tune_pv)
super().__init__(weight_name)
def load(self, weight_dict):
# match all keys with the prefix weight_name
for key in weight_dict.keys():
if key.startswith(self.weight_name):
sub_name = key.split(".")[-1]
setattr(
self.inner_cls,
sub_name,
nn.Parameter(weight_dict[key], requires_grad=False),
)
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
if len(q.shape) == 3:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
x = self.inner_cls(q, k, v, tensor_layout="NHD")
x = x.flatten(2)
x = x.squeeze(0)
return x
from loguru import logger
try:
import flash_attn
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
flash_attn_varlen_func = None
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
flash_attn_varlen_func_v3 = None
from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
mask_map=None,
):
x = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("flash_attn3")
class FlashAttn3Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
mask_map=None,
):
x = flash_attn_varlen_func_v3(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
)[0].reshape(max_seqlen_q, -1)
return x
import torch
from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
try:
import flashinfer
@@ -15,6 +17,42 @@ except ImportError:
###
@ATTN_WEIGHT_REGISTER("radial_attn")
class RadialAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
mask_map=None,
sparsity_type="radial",
block_size=128,
decay_factor=1,
model_cls="wan",
):
assert len(q.shape) == 3
x = radial_attn(
q,
k,
v,
mask_map=mask_map,
sparsity_type=sparsity_type,
block_size=block_size,
model_cls=model_cls[:3], # Use first 3 characters to match "wan", "wan2", etc.
decay_factor=decay_factor,
)
x = x.view(max_seqlen_q, -1)
return x
def radial_attn(
query, key, value, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"
):
......
import torch
from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
import torch.distributed as dist
from .utils.ring_comm import RingComm
import torch.nn.functional as F
from loguru import logger
try:
import flash_attn
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
flash_attn_varlen_func = None
@torch.jit.script
def _update_out_and_lse(
out,
lse,
block_out,
block_lse,
):
block_out = block_out.to(torch.float32)
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
# new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
# torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
# For additional context and discussion, please refer to:
# https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)
return out, lse
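# A minimal numerical check (illustrative shapes, float32 for precision) that the
# sigmoid/logsigmoid update above is equivalent to the naive log-sum-exp merge from the
# commented-out formulas: new_lse = logaddexp(lse, block_lse) and
# new_out = exp(lse - new_lse) * out + exp(block_lse - new_lse) * block_out.
def _check_update_out_and_lse_equivalence():
    lse = torch.randn(4, 1, dtype=torch.float32)
    block_lse = torch.randn(4, 1, dtype=torch.float32)
    out = torch.randn(4, 8, dtype=torch.float32)
    block_out = torch.randn(4, 8, dtype=torch.float32)
    out_stable = out - torch.sigmoid(block_lse - lse) * (out - block_out)
    lse_stable = lse - F.logsigmoid(lse - block_lse)
    new_lse = torch.logaddexp(lse, block_lse)
    out_naive = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    assert torch.allclose(out_stable, out_naive, atol=1e-6)
    assert torch.allclose(lse_stable, new_lse, atol=1e-6)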
@ATTN_WEIGHT_REGISTER("ring")
class RingAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None):
"""
执行 Ring 注意力机制,结合图像和文本的查询、键和值。
参数:
q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
img_qkv_len (int): 图像查询、键和值的长度
cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
attention_type (str): 注意力类型,默认为 "flash_attn2"
返回:
torch.Tensor: 计算得到的注意力结果
"""
# 获取当前进程的排名和全局进程数
cur_rank = dist.get_rank()
world_size = dist.get_world_size()
if len(cu_seqlens_qkv) == 3:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度
txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len # 文本掩码长度
elif len(cu_seqlens_qkv) == 2:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度
txt_mask_len = 0
# if RING_COMM is None:
# init_ring_comm()
RING_COMM = RingComm()
# if len(cu_seqlens_qkv) == 3:
# txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度
# txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len # 文本掩码长度
# elif len(cu_seqlens_qkv) == 2:
# txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度
# txt_mask_len = None
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous()
txt_q, txt_k, txt_v = (
q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
k[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
)
out, lse, next_k, next_v = None, None, None, None
if len(cu_seqlens_qkv) == 3:
q = torch.cat((img_q, txt_q), dim=1)
k = img_k
v = img_v
for step in range(world_size):
if step + 1 != world_size:
next_k = RING_COMM.send_recv(k)
next_v = RING_COMM.send_recv(v)
RING_COMM.commit()
if step + 1 == world_size:
k = torch.cat((k, txt_k), dim=1)
v = torch.cat((v, txt_v), dim=1)
block_out, block_lse = self.ring_attn_sub(q, k, v)
out, lse = self.update_out_and_lse(out, lse, block_out, block_lse)
if step + 1 != world_size:
RING_COMM.wait()
k = next_k
v = next_v
attn1 = out.to(torch.bfloat16).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
if txt_mask_len > 0:
attn2, *_ = flash_attn.flash_attn_interface._flash_attn_forward(
q[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
k[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
v[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
dropout_p=0.0,
softmax_scale=q.shape[-1] ** (-0.5),
causal=False,
window_size_left=-1,
window_size_right=-1,
softcap=0.0,
alibi_slopes=None,
return_softmax=False,
)
attn2 = attn2.to(torch.bfloat16).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
attn1 = torch.cat([attn1, attn2], dim=0)
return attn1
def ring_attn_sub(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward(
q,
k,
v,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax,
)
return block_out, block_lse
def update_out_and_lse(
self,
out,
lse,
block_out,
block_lse,
slice_=None,
):
if out is None:
if slice_ is not None:
raise RuntimeError("first update_out_and_lse should not pass slice_ args")
out = block_out.to(torch.float32)
lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
elif slice_ is not None:
slice_out, slice_lse = out[slice_], lse[slice_]
slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
out[slice_], lse[slice_] = slice_out, slice_lse
else:
out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
return out, lse
import torch
from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from loguru import logger
if torch.cuda.get_device_capability(0) == (8, 9):
try:
from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
except ImportError:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
else:
try:
from sageattention import sageattn
except ImportError:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
mask_map=None,
):
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
if model_cls == "hunyuan":
x1 = sageattn(
q[: cu_seqlens_q[1]].unsqueeze(0),
k[: cu_seqlens_kv[1]].unsqueeze(0),
v[: cu_seqlens_kv[1]].unsqueeze(0),
tensor_layout="NHD",
)
x2 = sageattn(
q[cu_seqlens_q[1] :].unsqueeze(0),
k[cu_seqlens_kv[1] :].unsqueeze(0),
v[cu_seqlens_kv[1] :].unsqueeze(0),
tensor_layout="NHD",
)
x = torch.cat((x1, x2), dim=1)
x = x.view(max_seqlen_q, -1)
elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df"]:
x = sageattn(
q.unsqueeze(0),
k.unsqueeze(0),
v.unsqueeze(0),
tensor_layout="NHD",
)
x = x.view(max_seqlen_q, -1)
return x
import torch
from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from loguru import logger
import torch.nn as nn
try:
from spas_sage_attn.autotune import SparseAttentionMeansim
except ImportError:
logger.info("SparseAttentionMeansim not found, please install sparge first")
SparseAttentionMeansim = None
@ATTN_WEIGHT_REGISTER("Sparge")
class SpargeAttnWeight(AttnWeightTemplate):
def __init__(
self,
weight_name,
verbose=False,
l1=0.07,
pv_l1=0.08,
tune_pv=True,
inner_attn_type="flash_attn3",
):
self.verbose = verbose
self.l1 = l1
self.pv_l1 = pv_l1
self.tune_pv = tune_pv
self.inner_attn_type = inner_attn_type
self.inner_cls = SparseAttentionMeansim(l1=l1, pv_l1=pv_l1, tune_pv=tune_pv)
super().__init__(weight_name)
def load(self, weight_dict):
# match all keys with the prefix weight_name
for key in weight_dict.keys():
if key.startswith(self.weight_name):
sub_name = key.split(".")[-1]
setattr(
self.inner_cls,
sub_name,
nn.Parameter(weight_dict[key], requires_grad=False),
)
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
if len(q.shape) == 3:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
x = self.inner_cls(q, k, v, tensor_layout="NHD")
x = x.flatten(2)
x = x.squeeze(0)
return x
from abc import ABCMeta, abstractmethod
class AttnWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name):
self.weight_name = weight_name
self.config = {}
def load(self, weight_dict):
pass
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
def to_cpu(self, non_blocking=False):
pass
def to_cuda(self, non_blocking=False):
pass
def state_dict(self, destination=None):
if destination is None:
destination = {}
return destination
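# Hypothetical sketch (not part of this commit) of how a new attention backend plugs
# into this template: subclass AttnWeightTemplate and register it under a new key.
# The key "my_attn" and the class below are illustrative only.
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER


@ATTN_WEIGHT_REGISTER("my_attn")
class MyAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
        # q/k/v arrive as [seqlen, heads, head_dim]; implementations return [seqlen, heads * head_dim]
        raise NotImplementedError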
import torch
import torch.nn.functional as F
from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
@ATTN_WEIGHT_REGISTER("torch_sdpa")
class TorchSDPAWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
drop_rate=0,
attn_mask=None,
causal=False,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
mask_map=None,
):
if q.ndim == 3:
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
x = x.transpose(1, 2)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out.squeeze(0)
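# A small usage sketch (shapes are illustrative): TorchSDPAWeight is stateless, so it
# can be applied directly to unbatched [seqlen, heads, head_dim] tensors, mirroring the
# call-site change later in this diff where attention(..., attention_type="torch_sdpa")
# is replaced with TorchSDPAWeight().apply(q=q, k=k, v=v).
def _torch_sdpa_example():
    s, h, d = 32, 8, 64
    q, k, v = torch.randn(s, h, d), torch.randn(s, h, d), torch.randn(s, h, d)
    out = TorchSDPAWeight().apply(q=q, k=k, v=v)
    return out.shape  # torch.Size([32, 512])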
import torch
from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
import torch.distributed as dist
from .utils.all2all import all2all_seq2head, all2all_head2seq
@ATTN_WEIGHT_REGISTER("ulysses")
class UlyssesAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None):
"""
执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。
参数:
q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
img_qkv_len (int): 图像查询、键和值的长度
cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
attention_type (str): 注意力类型,默认为 "flash_attn2"
返回:
torch.Tensor: 计算得到的注意力结果
"""
# 获取当前进程的排名和全局进程数
cur_rank = dist.get_rank()
world_size = dist.get_world_size()
# Sequence length and text-related lengths
seq_len = q.shape[0]
if len(cu_seqlens_qkv) == 3:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
elif len(cu_seqlens_qkv) == 2:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
txt_mask_len = None
# Number of heads and hidden dimension of the query tensor
_, heads, hidden_dims = q.shape
shard_heads = heads // world_size  # heads handled by each rank
shard_seqlen = img_qkv_len  # image sequence length handled by each rank
# Split the query/key/value into image and text parts
img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
# Redistribute the image query/key/value from sequence shards to head shards
img_q = all2all_seq2head(img_q)
img_k = all2all_seq2head(img_k)
img_v = all2all_seq2head(img_v)
torch.cuda.synchronize()  # make sure the CUDA ops have finished
# For the text query/key/value, keep only the heads belonging to the current rank
txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
# Concatenate the image and text query/key/value
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
v = torch.cat((img_v, txt_v), dim=0)
# Build the cumulative sequence length tensor
cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device="cuda")
s = txt_qkv_len + img_q.shape[0]  # total length of the text and image tokens
s1 = s  # end position of the current sample
cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
if txt_mask_len:
s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
cu_seqlens_qkv = torch.cat([cu_seqlens_qkv, torch.tensor([int(s2)], dtype=torch.int32, device="cuda")])  # torch.cat expects a sequence of tensors
max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length
# Compute the attention with the provided attention backend
# attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
# Split the attention output into image and text parts
img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :,]
# Gather the text attention results from all ranks
gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
dist.all_gather(gathered_txt_attn, txt_attn)
# Post-process the image attention result
img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # restore the head dimension
img_attn = all2all_head2seq(img_attn)  # redistribute from head shards back to sequence shards
img_attn = img_attn.reshape(shard_seqlen, -1)  # flatten to [shard_seqlen, -1]
torch.cuda.synchronize()  # make sure the CUDA ops have finished
txt_attn = torch.cat(gathered_txt_attn, dim=1)  # concatenate the per-rank text attention along the head dimension
# Concatenate the image and text attention results
attn = torch.cat([img_attn, txt_attn], dim=0)
return attn  # return the final attention result
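# Assumed shape contract, inferred from the reshapes above rather than from the all2all
# helpers themselves: all2all_seq2head maps this rank's image sequence shard with all
# heads, [shard_seqlen, heads, head_dim], to the full image sequence with only this
# rank's head shard, [world_size * shard_seqlen, heads // world_size, head_dim], and
# all2all_head2seq is the inverse. Text tokens are not redistributed: each rank keeps
# its own slice of heads and the per-rank text outputs are all-gathered at the end.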
@@ -8,7 +8,8 @@ import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from lightx2v.attentions import attention
# from lightx2v.attentions import attention
from lightx2v.common.ops.attn import TorchSDPAWeight
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8
@@ -84,7 +85,7 @@ class SelfAttention(nn.Module):
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
# compute attention
x = attention(q=q, k=k, v=v, attention_type="torch_sdpa")
x = TorchSDPAWeight().apply(q=q, k=k, v=v)
x = x.reshape(b, s, c)
# output
......
@@ -13,8 +13,6 @@ from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer im
HunyuanTransformerInferAdaCaching,
HunyuanTransformerInferCustomCaching,
)
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.utils.envs import *
from loguru import logger
from safetensors import safe_open
@@ -41,14 +39,6 @@ class HunyuanModel:
self._init_weights()
self._init_infer()
if config["parallel_attn_type"]:
if config["parallel_attn_type"] == "ulysses":
ulysses_dist_wrap.parallelize_hunyuan(self)
elif config["parallel_attn_type"] == "ring":
ring_dist_wrap.parallelize_hunyuan(self)
else:
raise Exception(f"Unsuppotred parallel_attn_type")
if self.config["cpu_offload"]:
self.to_cpu()
......
@@ -12,13 +12,8 @@ from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudi
from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from lightx2v.attentions.common.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.transformer_infer import (
WanTransformerInfer,
)
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
WanTransformerInferTeaCaching,
)
from safetensors import safe_open
from lightx2v.common.ops.attn.radial_attn import MaskMap
class WanAudioModel(WanModel):
@@ -30,14 +25,9 @@ class WanAudioModel(WanModel):
super().__init__(model_path, config, device)
def _init_infer_class(self):
super()._init_infer_class()
self.pre_infer_class = WanAudioPreInfer
self.post_infer_class = WanAudioPostInfer
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferTeaCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@torch.no_grad()
def infer(self, inputs):
......