Commit daf4c74e authored by helloyongyang, committed by Yang Yong (雍洋)

first commit

parent 6c79160f
import torch
from .mm_weight import MMWeight
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
@MM_WEIGHT_REGISTER('Calib')
class MMWeightCalib(MMWeight):
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
def load(self, weight_dict):
assert self.config and self.config.get('mm_type', 'Default') != 'Default'
self.weight = weight_dict[self.weight_name]
self.get_quantizer()
shape_and_dtype = self.get_quant_shape_and_dtype(self.weight.shape)
self.realq_weight, self.scales, self.zeros = self.w_quantizer.real_quant_tensor(self.weight)
self.realq_weight = self.realq_weight.view(shape_and_dtype['tensor'][0]).contiguous().to(shape_and_dtype['tensor'][1])
self.scales = self.scales.view(shape_and_dtype['scales'][0]).contiguous().to(shape_and_dtype['scales'][1])
if self.zeros is not None:
self.zeros = self.zeros.view(shape_and_dtype['zeros'][0]).contiguous().to(shape_and_dtype['zeros'][1])
def apply(self, input_tensor):
return super().apply(input_tensor)
def get_quantizer(self):
if self.config['mm_type'] == 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm':
self.w_setting = {
'bit': 'e4m3',
'symmetric': True,
'granularity': 'channel'
}
self.a_setting = {
'bit': 'e4m3',
'symmetric': True,
'granularity': 'channel'
}
self.w_quantizer = FloatQuantizer(**self.w_setting)
self.a_quantizer = FloatQuantizer(**self.a_setting)
self.act_dynamic_quant = True
else:
raise NotImplementedError(f'Unsupported mm_type: {self.config["mm_type"]}')
def get_quant_shape_and_dtype(self, shape):
if self.config['mm_type'] == 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm':
return {
'tensor': (shape, torch.float8_e4m3fn),  # storage dtype matches the e4m3 weight quantizer above
'scales': ((shape[0], 1), torch.float32),
'zeros': None,
}
else:
raise NotImplementedError(f'Unsupported mm_type: {self.config["mm_type"]}')
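# --- Hedged sketch, not part of the original file ---------------------------------------
# Illustrates the per-output-channel symmetric FP8 quantization that
# FloatQuantizer.real_quant_tensor is assumed to perform for the
# 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm' mm_type above. 448.0 is the finite
# max of torch.float8_e4m3fn; the real FloatQuantizer API may differ in detail.
def _fp8_channel_sym_quant_sketch(weight: torch.Tensor):
    # one scale per output channel (row); symmetric quantization has no zero-points
    amax = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scales = (amax / 448.0).to(torch.float32)
    realq = (weight / scales).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return realq, scales, None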
from .rms_norm_weight import *
import torch
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
class LNWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, eps=1e-6):
self.weight_name = weight_name
self.bias_name = bias_name
self.eps = eps
self.config = {}
def load(self, weight_dict):
self.weight = weight_dict[self.weight_name].cuda() if self.weight_name is not None else None
self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
def to_cpu(self):
if self.weight is not None:
self.weight = self.weight.cpu()
if self.bias is not None:
self.bias = self.bias.cpu()
def to_cuda(self):
if self.weight is not None:
self.weight = self.weight.cuda()
if self.bias is not None:
self.bias = self.bias.cuda()
@LN_WEIGHT_REGISTER('Default')
class LNWeight(LNWeightTemplate):
def __init__(self, weight_name, bias_name, eps=1e-6):
super().__init__(weight_name, bias_name, eps)
def apply(self, input_tensor):
input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[1],), self.weight, self.bias, self.eps)
return input_tensor
import torch
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
import sgl_kernel
class RMSWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, eps=1e-6):
self.weight_name = weight_name
self.eps = eps
self.config = {}
def load(self, weight_dict):
self.weight = weight_dict[self.weight_name].cuda()
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
def to_cpu(self):
self.weight = self.weight.cpu()
def to_cuda(self):
self.weight = self.weight.cuda()
@RMS_WEIGHT_REGISTER('Default')
class RMSWeight(RMSWeightTemplate):
def __init__(self, weight_name, eps=1e-6):
super().__init__(weight_name, eps)
def apply(self, input_tensor):
input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = input_tensor * self.weight
return input_tensor
@RMS_WEIGHT_REGISTER('FP32')
class RMSWeightFP32(RMSWeight):
def __init__(self, weight_name, eps=1e-6):
super().__init__(weight_name, eps)
def apply(self, input_tensor):
input_tensor = input_tensor.float()
input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = input_tensor.to(torch.bfloat16)
input_tensor = input_tensor * self.weight
return input_tensor
@RMS_WEIGHT_REGISTER('sgl-kernel')
class RMSWeightSgl(RMSWeight):
def __init__(self, weight_name, eps=1e-6):
super().__init__(weight_name, eps)
def apply(self, input_tensor):
input_tensor = input_tensor.contiguous()
orig_shape = input_tensor.shape
input_tensor = input_tensor.view(-1, orig_shape[-1])
input_tensor = sgl_kernel.rmsnorm(input_tensor, self.weight, self.eps).view(orig_shape)
return input_tensor
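# --- Hedged usage sketch, not part of the original file ---------------------------------
# Quick consistency check between the eager 'Default' RMS norm and the fused 'sgl-kernel'
# variant above. Assumes a CUDA device and an sgl_kernel build that accepts bfloat16
# inputs; the shapes and tolerance are illustrative only.
def _check_rms_variants(hidden_dim: int = 128, eps: float = 1e-6) -> bool:
    weight_dict = {"demo.weight": torch.randn(hidden_dim, dtype=torch.bfloat16)}
    eager, fused = RMSWeight("demo.weight", eps), RMSWeightSgl("demo.weight", eps)
    eager.load(weight_dict)
    fused.load(weight_dict)
    x = torch.randn(4, hidden_dim, device="cuda", dtype=torch.bfloat16)
    return torch.allclose(eager.apply(x.clone()), fused.apply(x.clone()), atol=1e-2)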
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from lightx2v.attentions import attention
from lightx2v.text2v.models.text_encoders.hf.t5.tokenizer import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta
__all__ = [
'XLMRobertaCLIP',
'clip_xlm_roberta_vit_h_14',
'CLIPModel',
]
def pos_interpolate(pos, seq_len):
if pos.size(1) == seq_len:
return pos
else:
src_grid = int(math.sqrt(pos.size(1)))
tar_grid = int(math.sqrt(seq_len))
n = pos.size(1) - src_grid * src_grid
return torch.cat([
pos[:, :n],
F.interpolate(
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
0, 3, 1, 2),
size=(tar_grid, tar_grid),
mode='bicubic',
align_corners=False).flatten(2).transpose(1, 2)
],
dim=1)
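# --- Hedged illustration, not part of the original file ---------------------------------
# pos_interpolate keeps any leading non-grid tokens (e.g. the class token) and bicubically
# resizes the square grid of patch position embeddings. The sizes below are illustrative:
# a 224px / patch-14 grid (16x16) resized for a 448px input (32x32).
def _pos_interpolate_example() -> torch.Size:
    dim = 1280
    pos = torch.randn(1, 1 + 16 * 16, dim)           # class token + 16x16 patch grid
    out = pos_interpolate(pos, seq_len=1 + 32 * 32)  # target: class token + 32x32 grid
    return out.shape                                  # torch.Size([1, 1025, 1280])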
class QuickGELU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(1.702 * x)
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
class SelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
causal=False,
attn_dropout=0.0,
proj_dropout=0.0):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.causal = causal
self.attn_dropout = attn_dropout
self.proj_dropout = proj_dropout
# layers
self.to_qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
# compute attention
x = attention(q=q, k=k, v=v, attention_type='torch_sdpa')
x = x.reshape(b, s, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
return x
class SwiGLU(nn.Module):
def __init__(self, dim, mid_dim):
super().__init__()
self.dim = dim
self.mid_dim = mid_dim
# layers
self.fc1 = nn.Linear(dim, mid_dim)
self.fc2 = nn.Linear(dim, mid_dim)
self.fc3 = nn.Linear(mid_dim, dim)
def forward(self, x):
x = F.silu(self.fc1(x)) * self.fc2(x)
x = self.fc3(x)
return x
class AttentionBlock(nn.Module):
def __init__(self,
dim,
mlp_ratio,
num_heads,
post_norm=False,
causal=False,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
norm_eps=1e-5):
assert activation in ['quick_gelu', 'gelu', 'swi_glu']
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.post_norm = post_norm
self.causal = causal
self.norm_eps = norm_eps
# layers
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
proj_dropout)
self.norm2 = LayerNorm(dim, eps=norm_eps)
if activation == 'swi_glu':
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
else:
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
if self.post_norm:
x = x + self.norm1(self.attn(x))
x = x + self.norm2(self.mlp(x))
else:
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class AttentionPool(nn.Module):
def __init__(self,
dim,
mlp_ratio,
num_heads,
activation='gelu',
proj_dropout=0.0,
norm_eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.proj_dropout = proj_dropout
self.norm_eps = norm_eps
# layers
gain = 1.0 / math.sqrt(dim)
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
# compute attention
x = attention(q=q, k=k, v=v, attention_type='torch_sdpa')
x = x.reshape(b, 1, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
# mlp
x = x + self.mlp(self.norm(x))
return x[:, 0]
class VisionTransformer(nn.Module):
def __init__(self,
image_size=224,
patch_size=16,
dim=768,
mlp_ratio=4,
out_dim=512,
num_heads=12,
num_layers=12,
pool_type='token',
pre_norm=True,
post_norm=False,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5):
if image_size % patch_size != 0:
print(
'[WARNING] image_size is not divisible by patch_size',
flush=True)
assert pool_type in ('token', 'token_fc', 'attn_pool')
out_dim = out_dim or dim
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size)**2
self.dim = dim
self.mlp_ratio = mlp_ratio
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.pool_type = pool_type
self.post_norm = post_norm
self.norm_eps = norm_eps
# embeddings
gain = 1.0 / math.sqrt(dim)
self.patch_embedding = nn.Conv2d(
3,
dim,
kernel_size=patch_size,
stride=patch_size,
bias=not pre_norm)
if pool_type in ('token', 'token_fc'):
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(gain * torch.randn(
1, self.num_patches +
(1 if pool_type in ('token', 'token_fc') else 0), dim))
self.dropout = nn.Dropout(embedding_dropout)
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.transformer = nn.Sequential(*[
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
activation, attn_dropout, proj_dropout, norm_eps)
for _ in range(num_layers)
])
self.post_norm = LayerNorm(dim, eps=norm_eps)
# head
if pool_type == 'token':
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
elif pool_type == 'token_fc':
self.head = nn.Linear(dim, out_dim)
elif pool_type == 'attn_pool':
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
proj_dropout, norm_eps)
def forward(self, x, interpolation=False, use_31_block=False):
b = x.size(0)
# embeddings
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
if self.pool_type in ('token', 'token_fc'):
x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
if interpolation:
e = pos_interpolate(self.pos_embedding, x.size(1))
else:
e = self.pos_embedding
x = self.dropout(x + e)
if self.pre_norm is not None:
x = self.pre_norm(x)
# transformer
if use_31_block:
x = self.transformer[:-1](x)
return x
else:
x = self.transformer(x)
return x
class XLMRobertaWithHead(XLMRoberta):
def __init__(self, **kwargs):
self.out_dim = kwargs.pop('out_dim')
super().__init__(**kwargs)
# head
mid_dim = (self.dim + self.out_dim) // 2
self.head = nn.Sequential(
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
nn.Linear(mid_dim, self.out_dim, bias=False))
def forward(self, ids):
# xlm-roberta
x = super().forward(ids)
# average pooling
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
# head
x = self.head(x)
return x
class XLMRobertaCLIP(nn.Module):
def __init__(self,
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool='token',
vision_pre_norm=True,
vision_post_norm=False,
activation='gelu',
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5):
super().__init__()
self.embed_dim = embed_dim
self.image_size = image_size
self.patch_size = patch_size
self.vision_dim = vision_dim
self.vision_mlp_ratio = vision_mlp_ratio
self.vision_heads = vision_heads
self.vision_layers = vision_layers
self.vision_pre_norm = vision_pre_norm
self.vision_post_norm = vision_post_norm
self.activation = activation
self.vocab_size = vocab_size
self.max_text_len = max_text_len
self.type_size = type_size
self.pad_id = pad_id
self.text_dim = text_dim
self.text_heads = text_heads
self.text_layers = text_layers
self.text_post_norm = text_post_norm
self.norm_eps = norm_eps
# models
self.visual = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
mlp_ratio=vision_mlp_ratio,
out_dim=embed_dim,
num_heads=vision_heads,
num_layers=vision_layers,
pool_type=vision_pool,
pre_norm=vision_pre_norm,
post_norm=vision_post_norm,
activation=activation,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout,
norm_eps=norm_eps)
self.textual = XLMRobertaWithHead(
vocab_size=vocab_size,
max_seq_len=max_text_len,
type_size=type_size,
pad_id=pad_id,
dim=text_dim,
out_dim=embed_dim,
num_heads=text_heads,
num_layers=text_layers,
post_norm=text_post_norm,
dropout=text_dropout)
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
def forward(self, imgs, txt_ids):
"""
imgs: [B, 3, H, W] of torch.float32.
- mean: [0.48145466, 0.4578275, 0.40821073]
- std: [0.26862954, 0.26130258, 0.27577711]
txt_ids: [B, L] of torch.long.
Encoded by data.CLIPTokenizer.
"""
xi = self.visual(imgs)
xt = self.textual(txt_ids)
return xi, xt
def param_groups(self):
groups = [{
'params': [
p for n, p in self.named_parameters()
if 'norm' in n or n.endswith('bias')
],
'weight_decay': 0.0
}, {
'params': [
p for n, p in self.named_parameters()
if not ('norm' in n or n.endswith('bias'))
]
}]
return groups
def _clip(pretrained=False,
pretrained_name=None,
model_cls=XLMRobertaCLIP,
return_transforms=False,
return_tokenizer=False,
tokenizer_padding='eos',
dtype=torch.float32,
device='cpu',
**kwargs):
# init a model on device
with torch.device(device):
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
output = (model,)
# init transforms
if return_transforms:
# mean and std
if 'siglip' in pretrained_name.lower():
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
else:
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
# transforms
transforms = T.Compose([
T.Resize((model.image_size, model.image_size),
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=mean, std=std)
])
output += (transforms,)
return output[0] if len(output) == 1 else output
def clip_xlm_roberta_vit_h_14(
pretrained=False,
pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
**kwargs):
cfg = dict(
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool='token',
activation='gelu',
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0)
cfg.update(**kwargs)
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
class CLIPModel:
def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
self.dtype = dtype
self.device = device
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
pretrained=False,
return_transforms=True,
return_tokenizer=False,
dtype=dtype,
device=device)
self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
self.model.load_state_dict(
torch.load(checkpoint_path, map_location='cpu', weights_only=True))
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path,
seq_len=self.model.max_text_len - 2,
clean='whitespace')
def visual(self, videos):
# preprocess
size = (self.model.image_size,) * 2
videos = torch.cat([
F.interpolate(
u.transpose(0, 1),
size=size,
mode='bicubic',
align_corners=False) for u in videos
])
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
with torch.amp.autocast('cuda', dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True)
return out
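# --- Hedged usage sketch, not part of the original file ---------------------------------
# Shows how CLIPModel.visual is expected to be called: a list of [C, T, H, W] clips with
# values in [-1, 1], returning features from all but the last ViT block (use_31_block=True).
# The checkpoint path, tokenizer name and clip shape below are hypothetical.
def _clip_visual_example():
    clip = CLIPModel(
        dtype=torch.float16,
        device="cuda",
        checkpoint_path="/path/to/clip_checkpoint.pth",  # hypothetical path
        tokenizer_path="xlm-roberta-large",              # hypothetical tokenizer id
    )
    videos = [torch.rand(3, 1, 480, 832, device="cuda") * 2 - 1]  # one single-frame clip
    return clip.visual(videos).shape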
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['XLMRoberta', 'xlm_roberta_large']
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
# compute attention
p = self.dropout.p if self.training else 0.0
x = F.scaled_dot_product_attention(q, k, v, mask, p)
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
# output
x = self.o(x)
x = self.dropout(x)
return x
class AttentionBlock(nn.Module):
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.post_norm = post_norm
self.eps = eps
# layers
self.attn = SelfAttention(dim, num_heads, dropout, eps)
self.norm1 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
nn.Dropout(dropout))
self.norm2 = nn.LayerNorm(dim, eps=eps)
def forward(self, x, mask):
if self.post_norm:
x = self.norm1(x + self.attn(x, mask))
x = self.norm2(x + self.ffn(x))
else:
x = x + self.attn(self.norm1(x), mask)
x = x + self.ffn(self.norm2(x))
return x
class XLMRoberta(nn.Module):
"""
XLMRobertaModel with no pooler and no LM head.
"""
def __init__(self,
vocab_size=250002,
max_seq_len=514,
type_size=1,
pad_id=1,
dim=1024,
num_heads=16,
num_layers=24,
post_norm=True,
dropout=0.1,
eps=1e-5):
super().__init__()
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.type_size = type_size
self.pad_id = pad_id
self.dim = dim
self.num_heads = num_heads
self.num_layers = num_layers
self.post_norm = post_norm
self.eps = eps
# embeddings
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
self.type_embedding = nn.Embedding(type_size, dim)
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
self.dropout = nn.Dropout(dropout)
# blocks
self.blocks = nn.ModuleList([
AttentionBlock(dim, num_heads, post_norm, dropout, eps)
for _ in range(num_layers)
])
# norm layer
self.norm = nn.LayerNorm(dim, eps=eps)
def forward(self, ids):
"""
ids: [B, L] of torch.LongTensor.
"""
b, s = ids.shape
mask = ids.ne(self.pad_id).long()
# embeddings
x = self.token_embedding(ids) + \
self.type_embedding(torch.zeros_like(ids)) + \
self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
if self.post_norm:
x = self.norm(x)
x = self.dropout(x)
# blocks
mask = torch.where(
mask.view(b, 1, 1, s).gt(0), 0.0,
torch.finfo(x.dtype).min)
for block in self.blocks:
x = block(x, mask)
# output
if not self.post_norm:
x = self.norm(x)
return x
def xlm_roberta_large(pretrained=False,
return_tokenizer=False,
device='cpu',
**kwargs):
"""
XLMRobertaLarge adapted from Huggingface.
"""
# params
cfg = dict(
vocab_size=250002,
max_seq_len=514,
type_size=1,
pad_id=1,
dim=1024,
num_heads=16,
num_layers=24,
post_norm=True,
dropout=0.1,
eps=1e-5)
cfg.update(**kwargs)
# init a model on device
with torch.device(device):
model = XLMRoberta(**cfg)
return model
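# --- Hedged illustration, not part of the original file ---------------------------------
# The position ids built in XLMRoberta.forward follow the Huggingface XLM-RoBERTa
# convention: pad_id + cumsum over the non-pad mask, so padded tokens land on the
# pos_embedding padding_idx. The ids below are illustrative.
def _position_id_example() -> torch.Tensor:
    pad_id = 1
    ids = torch.tensor([[0, 2300, 521, 2, 1, 1]])    # last two tokens are padding
    mask = ids.ne(pad_id).long()
    # -> tensor([[2, 3, 4, 5, 1, 1]]); the two padded positions hit padding_idx=1
    return pad_id + torch.cumsum(mask, dim=1) * mask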
import torch
from einops import rearrange
from lightx2v.attentions import attention
from ..utils_bf16 import apply_rotary_emb
from typing import Dict
import math
from ..transformer_infer import HunyuanTransformerInfer
def taylor_cache_init(cache_dic: Dict, current: Dict):
"""
Initialize the Taylor cache, allocating storage for this module's Taylor-series derivative factors
:param cache_dic: Cache dictionary
:param current: Information of the current step
"""
if current['step'] == 0:
cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = {}
def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
"""
Update the cached finite-difference derivative estimates from the newly computed feature
:param cache_dic: Cache dictionary
:param current: Information of the current step
:param feature: Feature tensor computed at the current activated step
"""
difference_distance = current['activated_steps'][-1] - current['activated_steps'][-2]
#difference_distance = current['activated_times'][-1] - current['activated_times'][-2]
updated_taylor_factors = {}
updated_taylor_factors[0] = feature
for i in range(cache_dic['max_order']):
if (cache_dic['cache'][-1][current['stream']][current['layer']][current['module']].get(i, None) is not None) and (current['step'] > cache_dic['first_enhance'] - 2):
updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic['cache'][-1][current['stream']][current['layer']][current['module']][i]) / difference_distance
else:
break
cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = updated_taylor_factors
def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
"""
Evaluate the cached Taylor expansion to approximate the feature at the current step
:param cache_dic: Cache dictionary
:param current: Information of the current step
"""
x = current['step'] - current['activated_steps'][-1]
#x = current['t'] - current['activated_times'][-1]
output = 0
for i in range(len(cache_dic['cache'][-1][current['stream']][current['layer']][current['module']])):
output += (1 / math.factorial(i)) * cache_dic['cache'][-1][current['stream']][current['layer']][current['module']][i] * (x ** i)
return output
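# --- Hedged numeric sketch, not part of the original file -------------------------------
# derivative_approximation stores finite-difference estimates F_i of the feature's
# derivatives at the last fully computed ("activated") step, and taylor_formula
# extrapolates F(step) ~= sum_i F_i * (step - step_activated)^i / i!. A scalar,
# first-order version of the same idea, with illustrative numbers:
def _taylor_extrapolation_sketch() -> float:
    f_prev, f_curr = 2.0, 3.0                             # feature at the two activated steps
    step_prev, step_curr = 0, 5                           # activated step indices
    order1 = (f_curr - f_prev) / (step_curr - step_prev)  # first-order finite difference
    x = 7 - step_curr                                     # now at cached step 7
    return f_curr + order1 * x                            # 3.0 + 0.2 * 2 = 3.4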
class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
def __init__(self, config):
super().__init__(config)
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
self.scheduler.current['stream'] = 'double_stream'
for i in range(self.double_blocks_num):
self.scheduler.current['layer'] = i
img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
x = torch.cat((img, txt), 0)
self.scheduler.current['stream'] = 'single_stream'
for i in range(self.single_blocks_num):
self.scheduler.current['layer'] = i
x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
img = x[:img_seq_len, ...]
return img, vec
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
vec_silu = torch.nn.functional.silu(vec)
img_mod_out = weights.img_mod.apply(vec_silu)
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = img_mod_out.chunk(6, dim=-1)
txt_mod_out = weights.txt_mod.apply(vec_silu)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = txt_mod_out.chunk(6, dim=-1)
if self.scheduler.current['type'] == 'full':
img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis)
txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
v = torch.cat((img_v, txt_v), dim=0)
if not self.parallel_attention:
attn = attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv,
)
else:
# world_size = dist.get_world_size()
attn = self.parallel_attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
img_qkv_len=img_q.shape[0],
cu_seqlens_qkv=cu_seqlens_qkv
# cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv,
)
img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]
img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate)
txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
return img, txt
elif self.scheduler.current['type'] == 'taylor_cache':
self.scheduler.current['module'] = 'img_attn'
out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
out = out * img_mod1_gate
img = img + out
self.scheduler.current['module'] = 'img_mlp'
out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
out = out * img_mod2_gate
img = img + out
self.scheduler.current['module'] = 'txt_attn'
out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
out = out * txt_mod1_gate
txt = txt + out
self.scheduler.current['module'] = 'txt_mlp'
out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
out = out * txt_mod2_gate
txt = txt + out
return img, txt
def infer_double_block_img_post_atten(self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate):
self.scheduler.current['module'] = 'img_attn'
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
out = weights.img_attn_proj.apply(img_attn)
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
out = out * img_mod1_gate
img = img + out
self.scheduler.current['module'] = 'img_mlp'
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
out = out * (1 + img_mod2_scale) + img_mod2_shift
out = weights.img_mlp_fc1.apply(out)
out = torch.nn.functional.gelu(out, approximate='tanh')
out = weights.img_mlp_fc2.apply(out)
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
out = out * img_mod2_gate
img = img + out
return img
def infer_double_block_txt_post_atten(self, weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate):
self.scheduler.current['module'] = 'txt_attn'
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
out = weights.txt_attn_proj.apply(txt_attn)
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
out = out * txt_mod1_gate
txt = txt + out
self.scheduler.current['module'] = 'txt_mlp'
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
out = out * (1 + txt_mod2_scale) + txt_mod2_shift
out = weights.txt_mlp_fc1.apply(out)
out = torch.nn.functional.gelu(out, approximate='tanh')
out = weights.txt_mlp_fc2.apply(out)
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
out = out * txt_mod2_gate
txt = txt + out
return txt
def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
out = torch.nn.functional.silu(vec)
out = weights.modulation.apply(out)
mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
if self.scheduler.current['type'] == 'full':
out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
x_mod = out * (1 + mod_scale) + mod_shift
x_mod = weights.linear1.apply(x_mod)
qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
self.scheduler.current['module'] = 'attn'
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
q = weights.q_norm.apply(q)
k = weights.k_norm.apply(k)
img_q, txt_q = q[:-txt_seq_len, :, :], q[-txt_seq_len:, :, :]
img_k, txt_k = k[:-txt_seq_len, :, :], k[-txt_seq_len:, :, :]
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
if not self.parallel_attention:
attn = attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv,
)
else:
attn = self.parallel_attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
img_qkv_len=img_q.shape[0],
cu_seqlens_qkv=cu_seqlens_qkv
# cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv,
)
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, attn)
self.scheduler.current['module'] = 'total'
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
out = torch.nn.functional.gelu(mlp, approximate='tanh')
out = torch.cat((attn, out), 1)
out = weights.linear2.apply(out)
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
out = out * mod_gate
x = x + out
return x
elif self.scheduler.current['type'] == 'taylor_cache':
self.scheduler.current['module'] = 'total'
out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
out = out * mod_gate
x = x + out
return x
import torch
class HunyuanPostInfer():
def __init__(self):
pass
def infer(self, weights, img, vec, shape):
out = torch.nn.functional.silu(vec)
out = weights.final_layer_adaLN_modulation_1.apply(out)
shift, scale = out.chunk(2, dim=1)
out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
out = out * (1 + scale) + shift
out = weights.final_layer_linear.apply(out.to(torch.float32))
_, _, ot, oh, ow = shape
patch_size = [1, 2, 2]
tt, th, tw = (
ot // patch_size[0],
oh // patch_size[1],
ow // patch_size[2],
)
c = 16
pt, ph, pw = patch_size
out = out.reshape(shape=(1, tt, th, tw, c, pt, ph, pw))
out = torch.einsum("nthwcopq->nctohpwq", out)
out = out.reshape(shape=(1, c, tt * pt, th * ph, tw * pw))
return out
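# Hedged note, not part of the original file: the reshape/einsum above un-patchifies the
# token sequence back into the latent video. With patch_size [1, 2, 2] and 16 latent
# channels, the tt*th*tw tokens of width 64 (= 16*1*2*2) become a latent tensor of shape
# (1, 16, tt, 2*th, 2*tw).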
import torch
import math
from einops import rearrange
from lightx2v.attentions import attention
class HunyuanPreInfer():
def __init__(self):
self.heads_num = 24
def infer(self, weights, x, t, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance):
time_out = self.infer_time_in(weights, t)
img_out = self.infer_img_in(weights, x)
infer_text_out = self.infer_text_in(weights, text_states, text_mask, t)
infer_vector_out = self.infer_vector_in(weights, text_states_2)
vec = time_out + infer_vector_out
guidance_out = self.infer_guidance_in(weights, guidance)
vec = vec + guidance_out
txt_seq_len = infer_text_out.shape[0]
img_seq_len = img_out.shape[1]
batch_size = text_mask.shape[0]
text_len = text_mask.sum(dim=1)
max_len = text_mask.shape[1] + img_seq_len
cu_seqlens_qkv = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
for i in range(batch_size):
s = text_len[i] + img_seq_len
s1 = i * max_len + s
s2 = (i + 1) * max_len
cu_seqlens_qkv[2 * i + 1] = s1
cu_seqlens_qkv[2 * i + 2] = s2
max_seqlen_qkv = img_seq_len + txt_seq_len
return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)
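# Hedged note, not part of the original file: cu_seqlens_qkv follows the varlen
# flash-attention convention of cumulative sequence boundaries. For each batch item i the
# loop above writes two offsets: the end of its valid tokens (img_seq_len + text_len[i])
# and the end of its padded slot (i + 1) * max_len, so padding tokens form their own
# segment and never attend across items. E.g. batch_size=1, img_seq_len=100,
# text_len=[20], max_len=156 -> cu_seqlens_qkv = [0, 120, 156].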
def infer_time_in(self, weights, t):
freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
out = weights.time_in_mlp_0.apply(embedding)
out = torch.nn.functional.silu(out)
out = weights.time_in_mlp_2.apply(out)
return out
def infer_img_in(self, weights, x):
out = weights.img_in_proj.apply(x)
out = out.flatten(2).transpose(1, 2)
return out
def infer_text_in(self, weights, text_states, text_mask, t):
freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
out = weights.txt_in_t_embedder_mlp_0.apply(embedding)
out = torch.nn.functional.silu(out)
timestep_aware_representations = weights.txt_in_t_embedder_mlp_2.apply(out)
mask_float = text_mask.float().unsqueeze(-1).to(torch.bfloat16) # [b, s1, 1]
context_aware_representations = (text_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
out = weights.txt_in_c_embedder_linear_1.apply(context_aware_representations)
out = torch.nn.functional.silu(out)
context_aware_representations = weights.txt_in_c_embedder_linear_2.apply(out)
c = timestep_aware_representations + context_aware_representations
txt_in_input_embed = weights.txt_in_input_embedder.apply(text_states[0])
batch_size = text_mask.shape[0]
seq_len = text_mask.shape[1]
self_attn_mask_1 = text_mask.view(batch_size, 1, 1, seq_len).repeat(
1, 1, seq_len, 1
)
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
self_attn_mask[:, :, :, 0] = True
cx = torch.nn.functional.silu(c)
cx = weights.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1.apply(cx)
gate_msa, gate_mlp = cx.chunk(2, dim=1)
normx = weights.txt_in_individual_token_refiner_blocks_0_norm1.apply(txt_in_input_embed)
qkv = weights.txt_in_individual_token_refiner_blocks_0_self_attn_qkv.apply(normx)
q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
attn = attention(attention_type="torch_sdpa", q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
out = weights.txt_in_individual_token_refiner_blocks_0_self_attn_proj.apply(attn)
out_1 = txt_in_input_embed + out * gate_msa
out = weights.txt_in_individual_token_refiner_blocks_0_norm2.apply(out_1)
# mlp
out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc1.apply(out)
out = torch.nn.functional.silu(out)
out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc2.apply(out)
txt_in_input_embed = out_1 + out * gate_mlp
cx = torch.nn.functional.silu(c)
cx = weights.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1.apply(cx)
gate_msa, gate_mlp = cx.chunk(2, dim=1)
normx = weights.txt_in_individual_token_refiner_blocks_1_norm1.apply(txt_in_input_embed)
qkv = weights.txt_in_individual_token_refiner_blocks_1_self_attn_qkv.apply(normx)
q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
attn = attention(attention_type="torch_sdpa", q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
out = weights.txt_in_individual_token_refiner_blocks_1_self_attn_proj.apply(attn)
out_1 = txt_in_input_embed + out * gate_msa
out = weights.txt_in_individual_token_refiner_blocks_1_norm2.apply(out_1)
# mlp
out = weights.txt_in_individual_token_refiner_blocks_1_mlp_fc1.apply(out)
out = torch.nn.functional.silu(out)
out = weights.txt_in_individual_token_refiner_blocks_1_mlp_fc2.apply(out)
out = out_1 + out * gate_mlp
return out
def infer_vector_in(self, weights, text_states_2):
out = weights.vector_in_in_layer.apply(text_states_2)
out = torch.nn.functional.silu(out)
out = weights.vector_in_out_layer.apply(out)
return out
def infer_guidance_in(self, weights, guidance):
freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=guidance.device)
args = guidance.float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
out = weights.guidance_in_mlp_0.apply(embedding)
out = torch.nn.functional.silu(out)
out = weights.guidance_in_mlp_2.apply(out)
return out
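# --- Hedged sketch, not part of the original file ---------------------------------------
# The timestep, text-timestep and guidance embeddings above all use the same sinusoidal
# scheme with half_dim=128: frequencies w_k = exp(-ln(10000) * k / 128) and the embedding
# [cos(t * w), sin(t * w)], i.e. a 256-dim vector per scalar input (cast to bfloat16 there).
def _sinusoidal_embedding_sketch(t: torch.Tensor, half_dim: int = 128) -> torch.Tensor:
    freqs = torch.exp(-math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / half_dim)
    args = t.float().reshape(-1, 1) * freqs.to(t.device)[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [N, 2 * half_dim]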
import torch
from einops import rearrange
from lightx2v.attentions import attention
from .utils_bf16 import apply_rotary_emb
class HunyuanTransformerInfer():
def __init__(self, config):
self.config = config
self.attention_type = config.get("attention_type", "flash_attn2")
self.double_blocks_num = 20
self.single_blocks_num = 40
self.heads_num = 24
self.hidden_size = 3072
self.mlp_hidden_dim = 12288
self.parallel_attention = None
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
for i in range(self.double_blocks_num):
img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
x = torch.cat((img, txt), 0)
for i in range(self.single_blocks_num):
x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
img = x[:img_seq_len, ...]
return img, vec
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
vec_silu = torch.nn.functional.silu(vec)
img_mod_out = weights.img_mod.apply(vec_silu)
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = img_mod_out.chunk(6, dim=-1)
txt_mod_out = weights.txt_mod.apply(vec_silu)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = txt_mod_out.chunk(6, dim=-1)
img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis)
txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
v = torch.cat((img_v, txt_v), dim=0)
if not self.parallel_attention:
attn = attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv,
)
else:
# world_size = dist.get_world_size()
attn = self.parallel_attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
img_qkv_len=img_q.shape[0],
cu_seqlens_qkv=cu_seqlens_qkv
# cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv,
)
img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]
img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate)
txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
return img, txt
def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis):
img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
img_qkv = weights.img_attn_qkv.apply(img_modulated)
img_q, img_k, img_v = rearrange(
img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num
)
img_q = weights.img_attn_q_norm.apply(img_q)
img_k = weights.img_attn_k_norm.apply(img_k)
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
return img_q, img_k, img_v
def infer_double_block_txt_pre_atten(self, weights, txt, txt_mod1_scale, txt_mod1_shift):
txt_modulated = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift
txt_qkv = weights.txt_attn_qkv.apply(txt_modulated)
txt_q, txt_k, txt_v = rearrange(
txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num
)
txt_q = weights.txt_attn_q_norm.apply(txt_q)
txt_k = weights.txt_attn_k_norm.apply(txt_k)
return txt_q, txt_k, txt_v
def infer_double_block_img_post_atten(self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate):
out = weights.img_attn_proj.apply(img_attn)
out = out * img_mod1_gate
img = img + out
out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
out = out * (1 + img_mod2_scale) + img_mod2_shift
out = weights.img_mlp_fc1.apply(out)
out = torch.nn.functional.gelu(out, approximate='tanh')
out = weights.img_mlp_fc2.apply(out)
out = out * img_mod2_gate
img = img + out
return img
def infer_double_block_txt_post_atten(self, weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate):
out = weights.txt_attn_proj.apply(txt_attn)
out = out * txt_mod1_gate
txt = txt + out
out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
out = out * (1 + txt_mod2_scale) + txt_mod2_shift
out = weights.txt_mlp_fc1.apply(out)
out = torch.nn.functional.gelu(out, approximate='tanh')
out = weights.txt_mlp_fc2.apply(out)
out = out * txt_mod2_gate
txt = txt + out
return txt
def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
out = torch.nn.functional.silu(vec)
out = weights.modulation.apply(out)
mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
x_mod = out * (1 + mod_scale) + mod_shift
x_mod = weights.linear1.apply(x_mod)
qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
q = weights.q_norm.apply(q)
k = weights.k_norm.apply(k)
img_q, txt_q = q[:-txt_seq_len, :, :], q[-txt_seq_len:, :, :]
img_k, txt_k = k[:-txt_seq_len, :, :], k[-txt_seq_len:, :, :]
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
if not self.parallel_attention:
attn = attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv,
)
else:
attn = self.parallel_attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
img_qkv_len=img_q.shape[0],
cu_seqlens_qkv=cu_seqlens_qkv
# cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv,
)
out = torch.nn.functional.gelu(mlp, approximate='tanh')
out = torch.cat((attn, out), 1)
out = weights.linear2.apply(out)
out = out * mod_gate
x = x + out
return x
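# --- Hedged note, not part of the original file -----------------------------------------
# Every block above follows the same adaLN-style modulation: the conditioning vector vec is
# passed through SiLU plus a linear projection to produce (shift, scale, gate) groups, each
# sub-layer normalizes its input as layer_norm(x) * (1 + scale) + shift, and the sub-layer
# output is added back as x + gate * out. A compact restatement of that pattern:
def _modulated_residual(x, sublayer, shift, scale, gate):
    # pre-modulate the sub-layer input, then add its output back through the gate
    h = torch.nn.functional.layer_norm(x, (x.shape[-1],), None, None, 1e-6)
    return x + gate * sublayer(h * (1 + scale) + shift)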
import sgl_kernel
def rms_norm(x, weight, eps):
x = x.contiguous()
orig_shape = x.shape
x = x.view(-1, orig_shape[-1])
x = sgl_kernel.rmsnorm(x, weight, eps).view(orig_shape)
return x
import torch
from typing import Any, List, Tuple, Optional, Union, Dict
def rms_norm(x, weight, eps):
x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
x = x * weight
return x
def rotate_half(x, shape_0, shape_1):
x_real, x_imag = x.reshape(shape_0, shape_1, -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(2)
def rotary_emb(x, shape_0, shape_1, cos, sin):
x_out = (x * cos + rotate_half(x, shape_0, shape_1) * sin)
return x_out
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
shape_0, shape_1, shape_2 = xq.shape
cos = freqs_cis[0].view(shape_0, 1, shape_2)
sin = freqs_cis[1].view(shape_0, 1, shape_2)
xq_out = rotary_emb(xq, shape_0, shape_1, cos, sin)
xk_out = rotary_emb(xk, shape_0, shape_1, cos, sin)
return xq_out, xk_out
import torch
from typing import Any, List, Tuple, Optional, Union, Dict
def rms_norm(x, weight, eps):
x = x.float()
x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
x = x.to(torch.bfloat16)
x = x * weight
return x
def rotate_half(x, shape_0, shape_1):
x_real, x_imag = x.float().reshape(shape_0, shape_1, -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(2)
def rotary_emb(x, shape_0, shape_1, cos, sin):
x_out = (x * cos + rotate_half(x, shape_0, shape_1) * sin)
return x_out.to(torch.bfloat16)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
shape_0, shape_1, shape_2 = xq.shape
cos = freqs_cis[0].view(shape_0, 1, shape_2)
sin = freqs_cis[1].view(shape_0, 1, shape_2)
xq_out = rotary_emb(xq.float(), shape_0, shape_1, cos, sin)
xk_out = rotary_emb(xk.float(), shape_0, shape_1, cos, sin)
return xq_out, xk_out
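# --- Hedged sketch, not part of the original file ---------------------------------------
# Both rotary helpers above use the interleaved RoPE convention: channels are grouped into
# (real, imag) pairs, rotated by (cos, sin) taken from freqs_cis, and this bf16 variant
# up-casts to float32 for the rotation before casting back. A minimal shape check with
# illustrative sizes:
def _rotary_shape_check() -> bool:
    seq_len, num_heads, head_dim = 6, 2, 8
    xq = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)
    xk = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)
    cos, sin = torch.randn(seq_len, head_dim), torch.randn(seq_len, head_dim)
    q_out, k_out = apply_rotary_emb(xq, xk, (cos, sin))
    return q_out.shape == xq.shape and k_out.shape == xk.shape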