Commit 1336a33d authored by zzg_666

wan2.2
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import numbers
from peft import LoraConfig
def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
target_modules = []
for name, module in transformer.named_modules():
if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear):
target_modules.append(name)
transformer_lora_config = LoraConfig(
r=rank,
lora_alpha=alpha,
init_lora_weights=init_lora_weights,
target_modules=target_modules,
)
return transformer_lora_config
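
# Usage sketch (illustrative, not part of this module): the returned LoraConfig is
# typically attached to a transformer that supports PEFT adapters, e.g. a diffusers
# model that mixes in PeftAdapterMixin (as WanAnimateModel does). The `transformer`
# instance below is assumed to already exist.
#
#     lora_config = get_loraconfig(transformer, rank=64, alpha=64)
#     transformer.add_adapter(lora_config)
#     lora_params = [p for n, p in transformer.named_parameters() if "lora_" in n]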
class TensorList(object):
def __init__(self, tensors):
"""
tensors: a list of torch.Tensor objects. No need to have uniform shape.
"""
assert isinstance(tensors, (list, tuple))
assert all(isinstance(u, torch.Tensor) for u in tensors)
assert len(set([u.ndim for u in tensors])) == 1
assert len(set([u.dtype for u in tensors])) == 1
assert len(set([u.device for u in tensors])) == 1
self.tensors = tensors
def to(self, *args, **kwargs):
return TensorList([u.to(*args, **kwargs) for u in self.tensors])
def size(self, dim):
assert dim == 0, 'only support get the 0th size'
return len(self.tensors)
def pow(self, *args, **kwargs):
return TensorList([u.pow(*args, **kwargs) for u in self.tensors])
def squeeze(self, dim):
assert dim != 0
if dim > 0:
dim -= 1
return TensorList([u.squeeze(dim) for u in self.tensors])
def type(self, *args, **kwargs):
return TensorList([u.type(*args, **kwargs) for u in self.tensors])
def type_as(self, other):
assert isinstance(other, (torch.Tensor, TensorList))
if isinstance(other, torch.Tensor):
return TensorList([u.type_as(other) for u in self.tensors])
else:
return TensorList([u.type(other.dtype) for u in self.tensors])
@property
def dtype(self):
return self.tensors[0].dtype
@property
def device(self):
return self.tensors[0].device
@property
def ndim(self):
return 1 + self.tensors[0].ndim
def __getitem__(self, index):
return self.tensors[index]
def __len__(self):
return len(self.tensors)
def __add__(self, other):
return self._apply(other, lambda u, v: u + v)
def __radd__(self, other):
return self._apply(other, lambda u, v: v + u)
def __sub__(self, other):
return self._apply(other, lambda u, v: u - v)
def __rsub__(self, other):
return self._apply(other, lambda u, v: v - u)
def __mul__(self, other):
return self._apply(other, lambda u, v: u * v)
def __rmul__(self, other):
return self._apply(other, lambda u, v: v * u)
def __floordiv__(self, other):
return self._apply(other, lambda u, v: u // v)
def __truediv__(self, other):
return self._apply(other, lambda u, v: u / v)
def __rfloordiv__(self, other):
return self._apply(other, lambda u, v: v // u)
def __rtruediv__(self, other):
return self._apply(other, lambda u, v: v / u)
def __pow__(self, other):
return self._apply(other, lambda u, v: u ** v)
def __rpow__(self, other):
return self._apply(other, lambda u, v: v ** u)
def __neg__(self):
return TensorList([-u for u in self.tensors])
def __iter__(self):
for tensor in self.tensors:
yield tensor
def __repr__(self):
return 'TensorList: \n' + repr(self.tensors)
def _apply(self, other, op):
if isinstance(other, (list, tuple, TensorList)) or (
isinstance(other, torch.Tensor) and (
other.numel() > 1 or other.ndim > 1
)
):
assert len(other) == len(self.tensors)
return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
elif isinstance(other, numbers.Number) or (
isinstance(other, torch.Tensor) and (
other.numel() == 1 and other.ndim <= 1
)
):
return TensorList([op(u, other) for u in self.tensors])
else:
raise TypeError(
f'unsupported operand for *: "TensorList" and "{type(other)}"'
)
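
# Minimal self-check sketch (not part of the module): TensorList broadcasts scalar
# operands and zips list/TensorList operands over tensors of different lengths.
if __name__ == "__main__":
    ts = TensorList([torch.randn(3, 4), torch.randn(5, 4)])
    assert ts.ndim == 3 and len(ts) == 2
    scaled = 2.0 * ts                                    # scalar broadcast via __rmul__
    shifted = ts + [torch.ones(3, 4), torch.ones(5, 4)]  # per-tensor elementwise add
    print(scaled[0].shape, shifted[1].shape)             # torch.Size([3, 4]) torch.Size([5, 4])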
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from ..attention import flash_attention
from ..tokenizers import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta
__all__ = [
'XLMRobertaCLIP',
'clip_xlm_roberta_vit_h_14',
'CLIPModel',
]
def pos_interpolate(pos, seq_len):
if pos.size(1) == seq_len:
return pos
else:
src_grid = int(math.sqrt(pos.size(1)))
tar_grid = int(math.sqrt(seq_len))
n = pos.size(1) - src_grid * src_grid
return torch.cat([
pos[:, :n],
F.interpolate(
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
0, 3, 1, 2),
size=(tar_grid, tar_grid),
mode='bicubic',
align_corners=False).flatten(2).transpose(1, 2)
],
dim=1)
class QuickGELU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(1.702 * x)
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
class SelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
causal=False,
attn_dropout=0.0,
proj_dropout=0.0):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.causal = causal
self.attn_dropout = attn_dropout
self.proj_dropout = proj_dropout
# layers
self.to_qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
# compute attention
p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
x = x.reshape(b, s, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
return x
class SwiGLU(nn.Module):
def __init__(self, dim, mid_dim):
super().__init__()
self.dim = dim
self.mid_dim = mid_dim
# layers
self.fc1 = nn.Linear(dim, mid_dim)
self.fc2 = nn.Linear(dim, mid_dim)
self.fc3 = nn.Linear(mid_dim, dim)
def forward(self, x):
x = F.silu(self.fc1(x)) * self.fc2(x)
x = self.fc3(x)
return x
class AttentionBlock(nn.Module):
def __init__(self,
dim,
mlp_ratio,
num_heads,
post_norm=False,
causal=False,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
norm_eps=1e-5):
assert activation in ['quick_gelu', 'gelu', 'swi_glu']
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.post_norm = post_norm
self.causal = causal
self.norm_eps = norm_eps
# layers
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
proj_dropout)
self.norm2 = LayerNorm(dim, eps=norm_eps)
if activation == 'swi_glu':
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
else:
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
if self.post_norm:
x = x + self.norm1(self.attn(x))
x = x + self.norm2(self.mlp(x))
else:
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class AttentionPool(nn.Module):
def __init__(self,
dim,
mlp_ratio,
num_heads,
activation='gelu',
proj_dropout=0.0,
norm_eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.proj_dropout = proj_dropout
self.norm_eps = norm_eps
# layers
gain = 1.0 / math.sqrt(dim)
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
# compute attention
x = flash_attention(q, k, v, version=2)
x = x.reshape(b, 1, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
# mlp
x = x + self.mlp(self.norm(x))
return x[:, 0]
class VisionTransformer(nn.Module):
def __init__(self,
image_size=224,
patch_size=16,
dim=768,
mlp_ratio=4,
out_dim=512,
num_heads=12,
num_layers=12,
pool_type='token',
pre_norm=True,
post_norm=False,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5):
if image_size % patch_size != 0:
print(
'[WARNING] image_size is not divisible by patch_size',
flush=True)
assert pool_type in ('token', 'token_fc', 'attn_pool')
out_dim = out_dim or dim
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size)**2
self.dim = dim
self.mlp_ratio = mlp_ratio
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.pool_type = pool_type
self.post_norm = post_norm
self.norm_eps = norm_eps
# embeddings
gain = 1.0 / math.sqrt(dim)
self.patch_embedding = nn.Conv2d(
3,
dim,
kernel_size=patch_size,
stride=patch_size,
bias=not pre_norm)
if pool_type in ('token', 'token_fc'):
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(gain * torch.randn(
1, self.num_patches +
(1 if pool_type in ('token', 'token_fc') else 0), dim))
self.dropout = nn.Dropout(embedding_dropout)
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.transformer = nn.Sequential(*[
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
activation, attn_dropout, proj_dropout, norm_eps)
for _ in range(num_layers)
])
self.post_norm = LayerNorm(dim, eps=norm_eps)
# head
if pool_type == 'token':
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
elif pool_type == 'token_fc':
self.head = nn.Linear(dim, out_dim)
elif pool_type == 'attn_pool':
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
proj_dropout, norm_eps)
def forward(self, x, interpolation=False, use_31_block=False):
b = x.size(0)
# embeddings
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
if self.pool_type in ('token', 'token_fc'):
x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
if interpolation:
e = pos_interpolate(self.pos_embedding, x.size(1))
else:
e = self.pos_embedding
x = self.dropout(x + e)
if self.pre_norm is not None:
x = self.pre_norm(x)
# transformer
if use_31_block:
x = self.transformer[:-1](x)
return x
else:
x = self.transformer(x)
return x
class XLMRobertaWithHead(XLMRoberta):
def __init__(self, **kwargs):
self.out_dim = kwargs.pop('out_dim')
super().__init__(**kwargs)
# head
mid_dim = (self.dim + self.out_dim) // 2
self.head = nn.Sequential(
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
nn.Linear(mid_dim, self.out_dim, bias=False))
def forward(self, ids):
# xlm-roberta
x = super().forward(ids)
# average pooling
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
# head
x = self.head(x)
return x
class XLMRobertaCLIP(nn.Module):
def __init__(self,
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool='token',
vision_pre_norm=True,
vision_post_norm=False,
activation='gelu',
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5):
super().__init__()
self.embed_dim = embed_dim
self.image_size = image_size
self.patch_size = patch_size
self.vision_dim = vision_dim
self.vision_mlp_ratio = vision_mlp_ratio
self.vision_heads = vision_heads
self.vision_layers = vision_layers
self.vision_pre_norm = vision_pre_norm
self.vision_post_norm = vision_post_norm
self.activation = activation
self.vocab_size = vocab_size
self.max_text_len = max_text_len
self.type_size = type_size
self.pad_id = pad_id
self.text_dim = text_dim
self.text_heads = text_heads
self.text_layers = text_layers
self.text_post_norm = text_post_norm
self.norm_eps = norm_eps
# models
self.visual = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
mlp_ratio=vision_mlp_ratio,
out_dim=embed_dim,
num_heads=vision_heads,
num_layers=vision_layers,
pool_type=vision_pool,
pre_norm=vision_pre_norm,
post_norm=vision_post_norm,
activation=activation,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout,
norm_eps=norm_eps)
self.textual = XLMRobertaWithHead(
vocab_size=vocab_size,
max_seq_len=max_text_len,
type_size=type_size,
pad_id=pad_id,
dim=text_dim,
out_dim=embed_dim,
num_heads=text_heads,
num_layers=text_layers,
post_norm=text_post_norm,
dropout=text_dropout)
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
def forward(self, imgs, txt_ids):
"""
imgs: [B, 3, H, W] of torch.float32.
- mean: [0.48145466, 0.4578275, 0.40821073]
- std: [0.26862954, 0.26130258, 0.27577711]
txt_ids: [B, L] of torch.long.
Encoded by data.CLIPTokenizer.
"""
xi = self.visual(imgs)
xt = self.textual(txt_ids)
return xi, xt
def param_groups(self):
groups = [{
'params': [
p for n, p in self.named_parameters()
if 'norm' in n or n.endswith('bias')
],
'weight_decay': 0.0
}, {
'params': [
p for n, p in self.named_parameters()
if not ('norm' in n or n.endswith('bias'))
]
}]
return groups
def _clip(pretrained=False,
pretrained_name=None,
model_cls=XLMRobertaCLIP,
return_transforms=False,
return_tokenizer=False,
tokenizer_padding='eos',
dtype=torch.float32,
device='cpu',
**kwargs):
# init a model on device
with torch.device(device):
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
output = (model,)
# init transforms
if return_transforms:
# mean and std
if 'siglip' in pretrained_name.lower():
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
else:
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
# transforms
transforms = T.Compose([
T.Resize((model.image_size, model.image_size),
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=mean, std=std)
])
output += (transforms,)
return output[0] if len(output) == 1 else output
def clip_xlm_roberta_vit_h_14(
pretrained=False,
pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
**kwargs):
cfg = dict(
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool='token',
activation='gelu',
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0)
cfg.update(**kwargs)
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
class CLIPModel:
def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
self.dtype = dtype
self.device = device
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
pretrained=False,
return_transforms=True,
return_tokenizer=False,
dtype=dtype,
device=device)
self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
self.model.load_state_dict(
torch.load(checkpoint_path, map_location='cpu'))
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path,
seq_len=self.model.max_text_len - 2,
clean='whitespace')
def visual(self, videos):
# preprocess
size = (self.model.image_size,) * 2
videos = torch.cat([
F.interpolate(
u.transpose(0, 1),
size=size,
mode='bicubic',
align_corners=False) for u in videos
])
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
with torch.cuda.amp.autocast(dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True)
return out
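
# Usage sketch (paths below are placeholders): CLIPModel wraps the XLM-RoBERTa CLIP
# visual tower for image conditioning. `visual` takes a list of videos shaped
# [C, T, H, W] with values in [-1, 1] and returns the penultimate-block ViT features
# (use_31_block=True).
#
#     clip = CLIPModel(
#         dtype=torch.float16,
#         device='cuda',
#         checkpoint_path='/path/to/clip_checkpoint.pth',
#         tokenizer_path='/path/to/xlm-roberta-large',
#     )
#     feats = clip.visual([video])   # video: [3, T, H, W] in [-1, 1]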
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from torch import nn
import torch
from typing import Tuple, Optional
from einops import rearrange
import torch.nn.functional as F
import math
from ...distributed.util import gather_forward, get_rank, get_world_size
try:
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
except ImportError:
flash_attn_func = None
MEMORY_LAYOUT = {
"flash": (
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
lambda x: x,
),
"torch": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
"vanilla": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
}
def attention(
q,
k,
v,
mode="flash",
drop_rate=0,
attn_mask=None,
causal=False,
max_seqlen_q=None,
batch_size=1,
):
"""
Perform QKV self attention.
Args:
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
drop_rate (float): Dropout rate in attention map. (default: 0)
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
(default: None)
causal (bool): Whether to use causal attention. (default: False)
cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
used to index into q.
cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
used to index into kv.
max_seqlen_q (int): The maximum sequence length in the batch of q.
max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
Returns:
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
"""
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
if mode == "torch":
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
elif mode == "flash":
x = flash_attn_func(
q,
k,
v,
)
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
elif mode == "vanilla":
scale_factor = 1 / math.sqrt(q.size(-1))
b, a, s, _ = q.shape
s1 = k.size(2)
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
if causal:
# Only applied to self attention
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(q.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn = (q @ k.transpose(-2, -1)) * scale_factor
attn += attn_bias
attn = attn.softmax(dim=-1)
attn = torch.dropout(attn, p=drop_rate, train=True)
x = attn @ v
else:
raise NotImplementedError(f"Unsupported attention mode: {mode}")
x = post_attn_layout(x)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
class CausalConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
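
# Shape/causality sketch (not part of the module): with kernel_size=3 and stride=1 the
# left-only padding keeps the temporal length unchanged and each output step only sees
# current and past inputs, e.g.
#
#     conv = CausalConv1d(8, 16, kernel_size=3)
#     y = conv(torch.randn(2, 8, 10))   # -> [2, 16, 10]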
class FaceEncoder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
self.out_proj = nn.Linear(1024, hidden_dim)
self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
def forward(self, x):
x = rearrange(x, "b t c -> b c t")
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv2(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv3(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm3(x)
x = self.act(x)
x = self.out_proj(x)
x = rearrange(x, "(b n) t c -> b t n c", b=b)
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
return x_local
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
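
# Quick check sketch (not part of the module): RMSNorm rescales each feature vector to
# roughly unit root-mean-square without centering it, e.g.
#
#     norm = RMSNorm(dim=8)
#     y = norm(torch.randn(2, 16, 8))
#     print(y.pow(2).mean(-1))   # values close to 1.0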
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
class FaceAdapter(nn.Module):
def __init__(
self,
hidden_dim: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
num_adapter_layers: int = 1,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.hidden_size = hidden_dim
self.heads_num = heads_num
self.fuser_blocks = nn.ModuleList(
[
FaceBlock(
self.hidden_size,
self.heads_num,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
)
for _ in range(num_adapter_layers)
]
)
def forward(
self,
x: torch.Tensor,
motion_embed: torch.Tensor,
idx: int,
freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
class FaceBlock(nn.Module):
def __init__(
self,
hidden_size: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
self.scale = qk_scale or head_dim**-0.5
self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
def forward(
self,
x: torch.Tensor,
motion_vec: torch.Tensor,
motion_mask: Optional[torch.Tensor] = None,
use_context_parallel=False,
) -> torch.Tensor:
B, T, N, C = motion_vec.shape
T_comp = T
x_motion = self.pre_norm_motion(motion_vec)
x_feat = self.pre_norm_feat(x)
kv = self.linear1_kv(x_motion)
q = self.linear1_q(x_feat)
k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
k = rearrange(k, "B L N H D -> (B L) N H D")
v = rearrange(v, "B L N H D -> (B L) N H D")
if use_context_parallel:
q = gather_forward(q, dim=1)
q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
# Compute attention.
attn = attention(
q,
k,
v,
max_seqlen_q=q.shape[1],
batch_size=q.shape[0],
)
attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
if use_context_parallel:
attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]
output = self.linear2(attn)
if motion_mask is not None:
output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
return output
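
# Shape note (for reference): x is the flattened video token sequence [B, T*S, C] and
# motion_vec holds the per-frame face tokens [B, T, N, C]; each frame's S spatial tokens
# cross-attend only to that frame's N face tokens, so attention runs over (B*T) groups.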
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import types
from copy import deepcopy
from einops import rearrange
from typing import List
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.loaders import PeftAdapterMixin
from ...distributed.sequence_parallel import (
distributed_attention,
gather_forward,
get_rank,
get_world_size,
)
from ..model import (
Head,
WanAttentionBlock,
WanLayerNorm,
WanRMSNorm,
WanModel,
WanSelfAttention,
flash_attention,
rope_params,
sinusoidal_embedding_1d,
rope_apply
)
from .face_blocks import FaceEncoder, FaceAdapter
from .motion_encoder import Generator
class HeadAnimate(Head):
def forward(self, x, e):
"""
Args:
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, L1, C]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
class WanAnimateSelfAttention(WanSelfAttention):
def forward(self, x, seq_lens, grid_sizes, freqs):
"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
x = flash_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanAnimateCrossAttention(WanSelfAttention):
def __init__(
self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6,
use_img_emb=True
):
super().__init__(
dim,
num_heads,
window_size,
qk_norm,
eps
)
self.use_img_emb = use_img_emb
if use_img_emb:
self.k_img = nn.Linear(dim, dim)
self.v_img = nn.Linear(dim, dim)
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, context, context_lens):
"""
x: [B, L1, C].
context: [B, L2, C].
context_lens: [B].
"""
if self.use_img_emb:
context_img = context[:, :257]
context = context[:, 257:]
else:
context = context
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
if self.use_img_emb:
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
v_img = self.v_img(context_img).view(b, -1, n, d)
img_x = flash_attention(q, k_img, v_img, k_lens=None)
# compute attention
x = flash_attention(q, k, v, k_lens=context_lens)
# output
x = x.flatten(2)
if self.use_img_emb:
img_x = img_x.flatten(2)
x = x + img_x
x = self.o(x)
return x
class WanAnimateAttentionBlock(nn.Module):
def __init__(self,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
use_img_emb=True):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanAnimateSelfAttention(dim, num_heads, window_size, qk_norm, eps)
self.norm3 = WanLayerNorm(
dim, eps, elementwise_affine=True
) if cross_attn_norm else nn.Identity()
self.cross_attn = WanAnimateCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps, use_img_emb=use_img_emb)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim),
nn.GELU(approximate='tanh'),
nn.Linear(ffn_dim, dim)
)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)
def forward(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
):
"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, L1, 6, C]
seq_lens(Tensor): Shape [B], length of each sequence in batch
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs
)
with amp.autocast(dtype=torch.float32):
x = x + y * e[2]
# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
with amp.autocast(dtype=torch.float32):
x = x + y * e[5]
return x
x = cross_attn_ffn(x, context, context_lens, e)
return x
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.LayerNorm(in_dim),
torch.nn.Linear(in_dim, in_dim),
torch.nn.GELU(),
torch.nn.Linear(in_dim, out_dim),
torch.nn.LayerNorm(out_dim),
)
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_no_split_modules = ['WanAttentionBlock']
@register_to_config
def __init__(self,
patch_size=(1, 2, 2),
text_len=512,
in_dim=36,
dim=5120,
ffn_dim=13824,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=40,
num_layers=40,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
motion_encoder_dim=512,
use_context_parallel=False,
use_img_emb=True):
super().__init__()
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
self.motion_encoder_dim = motion_encoder_dim
self.use_context_parallel = use_context_parallel
self.use_img_emb = use_img_emb
# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.pose_patch_embedding = nn.Conv3d(
16, dim, kernel_size=patch_size, stride=patch_size
)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
self.blocks = nn.ModuleList([
WanAnimateAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
cross_attn_norm, eps, use_img_emb) for _ in range(num_layers)
])
# head
self.head = HeadAnimate(dim, out_dim, patch_size, eps)
# buffers (don't use register_buffer otherwise dtype will be changed in to())
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
], dim=1)
self.img_emb = MLPProj(1280, dim)
# initialize weights
self.init_weights()
self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
self.face_adapter = FaceAdapter(
heads_num=self.num_heads,
hidden_dim=self.dim,
num_adapter_layers=self.num_layers // 5,
)
self.face_encoder = FaceEncoder(
in_dim=motion_encoder_dim,
hidden_dim=self.dim,
num_heads=4,
)
def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents]
for x_, pose_latents_ in zip(x, pose_latents):
x_[:, :, 1:] += pose_latents_
b,c,T,h,w = face_pixel_values.shape
face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
encode_bs = 8
face_pixel_values_tmp = []
for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
motion_vec = torch.cat(face_pixel_values_tmp)
motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
motion_vec = self.face_encoder(motion_vec)
B, L, H, C = motion_vec.shape
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
return x, motion_vec
def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
if block_idx % 5 == 0:
adapter_args = [x, motion_vec, motion_masks, self.use_context_parallel]
residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
x = residual_out + x
return x
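
    # Note: the face adapter holds num_layers // 5 FaceBlocks; after_transformer_block
    # fuses one of them in after every 5th transformer block (block_idx % 5 == 0),
    # selected by block_idx // 5.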
def forward(
self,
x,
t,
clip_fea,
context,
seq_len,
y=None,
pose_latents=None,
face_pixel_values=None
):
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float()
)
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if self.use_img_emb:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
if self.use_context_parallel:
x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
for idx, block in enumerate(self.blocks):
x = block(x, **kwargs)
x = self.after_transformer_block(idx, x, motion_vec)
# head
x = self.head(x, e)
if self.use_context_parallel:
x = gather_forward(x, dim=1)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
r"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Tensor]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Tensor):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Tensor]:
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
"""
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
def init_weights(self):
r"""
Initialize model parameters using Xavier initialization.
"""
# basic init
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
# init embeddings
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
for m in self.text_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
for m in self.time_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
# init output layer
nn.init.zeros_(self.head.head.weight)
# Modified from ``https://github.com/wyhsirius/LIA``
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
def custom_qr(input_tensor):
original_dtype = input_tensor.dtype
if original_dtype == torch.bfloat16:
q, r = torch.linalg.qr(input_tensor.to(torch.float32))
return q.to(original_dtype), r.to(original_dtype)
return torch.linalg.qr(input_tensor)
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
return F.leaky_relu(input + bias, negative_slope) * scale
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, minor, in_h, in_w = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, minor, in_h, 1, in_w, 1)
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
return out[:, :, ::down_y, ::down_x]
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
def make_kernel(k):
k = torch.tensor(k, dtype=torch.float32)
if k.ndim == 1:
k = k[None, :] * k[:, None]
k /= k.sum()
return k
class FusedLeakyReLU(nn.Module):
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
super().__init__()
self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
return out
class Blur(nn.Module):
def __init__(self, kernel, pad, upsample_factor=1):
super().__init__()
kernel = make_kernel(kernel)
if upsample_factor > 1:
kernel = kernel * (upsample_factor ** 2)
self.register_buffer('kernel', kernel)
self.pad = pad
def forward(self, input):
return upfirdn2d(input, self.kernel, pad=self.pad)
class ScaledLeakyReLU(nn.Module):
def __init__(self, negative_slope=0.2):
super().__init__()
self.negative_slope = negative_slope
def forward(self, input):
return F.leaky_relu(input, negative_slope=self.negative_slope)
class EqualConv2d(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
self.stride = stride
self.padding = padding
if bias:
self.bias = nn.Parameter(torch.zeros(out_channel))
else:
self.bias = None
def forward(self, input):
return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
def __repr__(self):
return (
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
)
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.activation = activation
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
if self.activation:
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)
else:
out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
return out
def __repr__(self):
return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
class ConvLayer(nn.Sequential):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
downsample=False,
blur_kernel=[1, 3, 3, 1],
bias=True,
activate=True,
):
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
stride = 2
self.padding = 0
else:
stride = 1
self.padding = kernel_size // 2
layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
bias=bias and not activate))
if activate:
if bias:
layers.append(FusedLeakyReLU(out_channel))
else:
layers.append(ScaledLeakyReLU(0.2))
super().__init__(*layers)
class ResBlock(nn.Module):
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
super().__init__()
self.conv1 = ConvLayer(in_channel, in_channel, 3)
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
def forward(self, input):
out = self.conv1(input)
out = self.conv2(out)
skip = self.skip(input)
out = (out + skip) / math.sqrt(2)
return out
class EncoderApp(nn.Module):
def __init__(self, size, w_dim=512):
super(EncoderApp, self).__init__()
channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256,
128: 128,
256: 64,
512: 32,
1024: 16
}
self.w_dim = w_dim
log_size = int(math.log(size, 2))
self.convs = nn.ModuleList()
self.convs.append(ConvLayer(3, channels[size], 1))
in_channel = channels[size]
for i in range(log_size, 2, -1):
out_channel = channels[2 ** (i - 1)]
self.convs.append(ResBlock(in_channel, out_channel))
in_channel = out_channel
self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
def forward(self, x):
res = []
h = x
for conv in self.convs:
h = conv(h)
res.append(h)
return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
class Encoder(nn.Module):
def __init__(self, size, dim=512, dim_motion=20):
super(Encoder, self).__init__()
        # appearance network
self.net_app = EncoderApp(size, dim)
# motion network
fc = [EqualLinear(dim, dim)]
for i in range(3):
fc.append(EqualLinear(dim, dim))
fc.append(EqualLinear(dim, dim_motion))
self.fc = nn.Sequential(*fc)
def enc_app(self, x):
h_source = self.net_app(x)
return h_source
def enc_motion(self, x):
h, _ = self.net_app(x)
h_motion = self.fc(h)
return h_motion
class Direction(nn.Module):
def __init__(self, motion_dim):
super(Direction, self).__init__()
self.weight = nn.Parameter(torch.randn(512, motion_dim))
def forward(self, input):
weight = self.weight + 1e-8
Q, R = custom_qr(weight)
if input is None:
return Q
else:
input_diag = torch.diag_embed(input) # alpha, diagonal matrix
out = torch.matmul(input_diag, Q.T)
out = torch.sum(out, dim=1)
return out
class Synthesis(nn.Module):
def __init__(self, motion_dim):
super(Synthesis, self).__init__()
self.direction = Direction(motion_dim)
class Generator(nn.Module):
def __init__(self, size, style_dim=512, motion_dim=20):
super().__init__()
self.enc = Encoder(size, style_dim, motion_dim)
self.dec = Synthesis(motion_dim)
def get_motion(self, img):
#motion_feat = self.enc.enc_motion(img)
motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
with torch.cuda.amp.autocast(dtype=torch.float32):
motion = self.dec.direction(motion_feat)
return motion
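
# Shape sketch (derived from the modules above, for reference only): for a Generator
# built with size=512, style_dim=512, motion_dim=20, `get_motion` maps a batch of
# 512x512 face crops to per-frame motion vectors:
#
#     gen = Generator(size=512, style_dim=512, motion_dim=20)
#     motion = gen.get_motion(torch.randn(4, 3, 512, 512))   # -> [4, 512]
#
# The encoder's 20-dim motion code is expanded back onto a 512-dim orthonormal
# direction basis (via QR), which is what FaceEncoder (motion_encoder_dim=512) consumes.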
# Wan-animate Preprocessing User Guide
## 1. Introduction
Wan-animate offers two generation modes: `animation` and `replacement`. While both modes extract the skeleton from the driving video, each has a distinct preprocessing pipeline.
### 1.1 Animation Mode
In this mode, it is highly recommended to enable pose retargeting, especially if the body proportions of the reference and driving characters are dissimilar.
- A simplified version of the pose retargeting pipeline is provided to help developers quickly implement this functionality.
- **NOTE:** Due to the potential complexity of input data, the results from this simplified retargeting version are NOT guaranteed to be perfect. It is strongly advised to verify the preprocessing results before proceeding.
- Community contributions to improve this feature are welcome.
### 1.2 Replacement Mode
- Pose retargeting is DISABLED by default in this mode. This is a deliberate choice to account for potential spatial interactions between the character and the environment.
- **WARNING**: If there is a significant mismatch in body proportions between the reference and driving characters, artifacts or deformations may appear in the final output.
- A simplified version for extracting the character's mask is also provided.
- **WARNING:** This mask extraction process is designed for **single-person videos ONLY** and may produce incorrect results or fail in multi-person videos (incorrect pose tracking). For multi-person video, users are required to either develop their own solution or integrate a suitable open-source tool.
---
## 2. Preprocessing Instructions and Recommendations
### 2.1 Basic Usage
- The preprocessing pipeline requires some additional models: pose detection (mandatory), plus mask extraction and image editing models (optional, as needed). Place them according to the following directory structure:
```
/path/to/your/ckpt_path/
├── det/
│ └── yolov10m.onnx
├── pose2d/
│ └── vitpose_h_wholebody.onnx
├── sam2/
│ └── sam2_hiera_large.pt
└── FLUX.1-Kontext-dev/
```
- `video_path`, `refer_path`, and `save_path` correspond to the paths for the input driving video, the character image, and the preprocessed results.
- When using `animation` mode, two videos, `src_face.mp4` and `src_pose.mp4`, will be generated in `save_path`. When using `replacement` mode, two additional videos, `src_bg.mp4` and `src_mask.mp4`, will also be generated.
- The `resolution_area` parameter determines the resolution used for both preprocessing and the generation model; the target size is specified by its pixel area (see the sketch below).
- The `fps` parameter specifies the frame rate used for video processing. A lower frame rate improves generation efficiency but may cause stuttering or choppiness.
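As a rough illustration of how a pixel-area budget maps to an output size, the sketch below keeps the source aspect ratio and rounds to a multiple of 16. This is only an assumption for illustration: `size_from_area` is not part of the actual pipeline, which may round differently.
```python
import math

def size_from_area(src_w, src_h, resolution_area=(1280, 720), multiple=16):
    # Pick an output size with roughly the requested pixel area while keeping
    # the source aspect ratio (illustrative only).
    area = resolution_area[0] * resolution_area[1]
    aspect = src_w / src_h
    out_h = int(round(math.sqrt(area / aspect) / multiple)) * multiple
    out_w = int(round(math.sqrt(area * aspect) / multiple)) * multiple
    return out_w, out_h

print(size_from_area(1920, 1080))  # -> (1280, 720)
```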
---
### 2.2 Animation Mode
- Three options are supported: no pose retargeting, basic pose retargeting, and enhanced pose retargeting based on the `FLUX.1-Kontext-dev` image editing model. These are selected via the `retarget_flag` and `use_flux` parameters.
- Using basic pose retargeting (specifying only `retarget_flag`) requires that both the reference character and the character in the first frame of the driving video are in a front-facing, stretched-out pose.
- Otherwise, we recommend enhanced pose retargeting by specifying both `retarget_flag` and `use_flux`. **NOTE:** Due to the limited capabilities of `FLUX.1-Kontext-dev`, it is NOT guaranteed to produce the expected results (e.g., consistency is not maintained, or the pose is incorrect). It is recommended to check both the intermediate results and the final generated pose video; both are stored in `save_path`. Users can also substitute a better image editing model, or explore the Flux prompts on their own.
---
### 2.3 Replacement Mode
- Specify `replace_flag` to enable data preprocessing for this mode. Preprocessing will additionally produce a mask for the character in the video; its size and shape can be adjusted via several parameters.
- `iterations` and `k` enlarge the mask so that it covers more area (see the dilation sketch below).
- `w_len` and `h_len` adjust the mask's shape: smaller values make the outline coarser, larger values make it finer.
- A smaller, finer-contoured mask preserves more of the original background, but may limit the character's generation area (given potential appearance differences, this can lead to some shape leakage). A larger, coarser mask gives the character generation more flexibility and consistency, but since it covers more of the background, it may affect the background's consistency. We recommend adjusting these parameters based on the specific input data.
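The mask-growing behaviour controlled by `iterations` and `k` typically corresponds to morphological dilation. The sketch below only illustrates that effect on a single mask frame with OpenCV; it is not the pipeline's actual implementation, and the file names are placeholders.
```python
import cv2
import numpy as np

# One binary mask frame (255 = character region), e.g. extracted from src_mask.mp4.
mask = cv2.imread("mask_frame.png", cv2.IMREAD_GRAYSCALE)

k = 7            # larger kernel   -> coarser, larger mask outline
iterations = 3   # more iterations -> the mask grows further outward

kernel = np.ones((k, k), np.uint8)
dilated = cv2.dilate(mask, kernel, iterations=iterations)
cv2.imwrite("mask_frame_dilated.png", dilated)
```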
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .process_pipepline import ProcessPipeline
from .video_predictor import SAM2VideoPredictor
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import cv2
import time
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List
import random
from pose2d_utils import AAPoseMeta
def draw_handpose(canvas, keypoints, hand_score_th=0.6):
"""
Draw keypoints and connections representing hand pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
or None if no keypoints are present.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
eps = 0.01
H, W, C = canvas.shape
stickwidth = max(int(min(H, W) / 200), 1)
edges = [
[0, 1],
[1, 2],
[2, 3],
[3, 4],
[0, 5],
[5, 6],
[6, 7],
[7, 8],
[0, 9],
[9, 10],
[10, 11],
[11, 12],
[0, 13],
[13, 14],
[14, 15],
[15, 16],
[0, 17],
[17, 18],
[18, 19],
[19, 20],
]
for ie, (e1, e2) in enumerate(edges):
k1 = keypoints[e1]
k2 = keypoints[e2]
if k1 is None or k2 is None:
continue
if k1[2] < hand_score_th or k2[2] < hand_score_th:
continue
x1 = int(k1[0])
y1 = int(k1[1])
x2 = int(k2[0])
y2 = int(k2[1])
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
cv2.line(
canvas,
(x1, y1),
(x2, y2),
matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
thickness=stickwidth,
)
for keypoint in keypoints:
if keypoint is None:
continue
if keypoint[2] < hand_score_th:
continue
x, y = keypoint[0], keypoint[1]
x = int(x)
y = int(y)
if x > eps and y > eps:
cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1)
return canvas
def draw_handpose_new(canvas, keypoints, stickwidth_type='v2', hand_score_th=0.6):
"""
Draw keypoints and connections representing hand pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
or None if no keypoints are present.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
eps = 0.01
H, W, C = canvas.shape
if stickwidth_type == 'v1':
stickwidth = max(int(min(H, W) / 200), 1)
elif stickwidth_type == 'v2':
stickwidth = max(max(int(min(H, W) / 200) - 1, 1) // 2, 1)
edges = [
[0, 1],
[1, 2],
[2, 3],
[3, 4],
[0, 5],
[5, 6],
[6, 7],
[7, 8],
[0, 9],
[9, 10],
[10, 11],
[11, 12],
[0, 13],
[13, 14],
[14, 15],
[15, 16],
[0, 17],
[17, 18],
[18, 19],
[19, 20],
]
for ie, (e1, e2) in enumerate(edges):
k1 = keypoints[e1]
k2 = keypoints[e2]
if k1 is None or k2 is None:
continue
if k1[2] < hand_score_th or k2[2] < hand_score_th:
continue
x1 = int(k1[0])
y1 = int(k1[1])
x2 = int(k2[0])
y2 = int(k2[1])
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
cv2.line(
canvas,
(x1, y1),
(x2, y2),
matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
thickness=stickwidth,
)
for keypoint in keypoints:
if keypoint is None:
continue
if keypoint[2] < hand_score_th:
continue
x, y = keypoint[0], keypoint[1]
x = int(x)
y = int(y)
if x > eps and y > eps:
cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1)
return canvas
def draw_ellipse_by_2kp(img, keypoint1, keypoint2, color, threshold=0.6):
H, W, C = img.shape
stickwidth = max(int(min(H, W) / 200), 1)
if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
return img
Y = np.array([keypoint1[0], keypoint2[0]])
X = np.array([keypoint1[1], keypoint2[1]])
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])
return img
def split_pose2d_kps_to_aa(kp2ds: np.ndarray) -> List[np.ndarray]:
"""Convert the 133 keypoints from pose2d to body and hands keypoints.
Args:
kp2ds (np.ndarray): [133, 2]
    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: body keypoints [20, 2],
            left-hand keypoints [21, 2], right-hand keypoints [21, 2].
"""
kp2ds_body = (
kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]]
+ kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]
) / 2
kp2ds_lhand = kp2ds[91:112]
kp2ds_rhand = kp2ds[112:133]
return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy()
def draw_aapose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=True, draw_head=True):
kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
pose_img = draw_aapose(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand, stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head)
return pose_img
def draw_aapose_by_meta_new(img, meta: AAPoseMeta, threshold=0.5, stickwidth_type='v2', draw_hand=True, draw_head=True):
kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
pose_img = draw_aapose_new(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand,
stickwidth_type=stickwidth_type, draw_hand=draw_hand, draw_head=draw_head)
return pose_img
def draw_hand_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200):
kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None] * 0], axis=1)
kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
pose_img = draw_aapose(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand, stick_width_norm=stick_width_norm, draw_hand=True, draw_head=False)
return pose_img
def draw_aaface_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=False, draw_head=True):
kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
# kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
# kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
pose_img = draw_M(img, kp2ds, threshold, kp2ds_lhand=None, kp2ds_rhand=None, stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head)
return pose_img
def draw_aanose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=100, draw_hand=False):
kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
# kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
# kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
pose_img = draw_nose(img, kp2ds, threshold, kp2ds_lhand=None, kp2ds_rhand=None, stick_width_norm=stick_width_norm, draw_hand=draw_hand)
return pose_img
def gen_face_motion_seq(img, metas: List[AAPoseMeta], threshold=0.5, stick_width_norm=200):
return
def draw_M(
img,
kp2ds,
threshold=0.6,
data_to_json=None,
idx=-1,
kp2ds_lhand=None,
kp2ds_rhand=None,
draw_hand=False,
stick_width_norm=200,
draw_head=True
):
"""
Draw keypoints and connections representing hand pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
or None if no keypoints are present.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
new_kep_list = [
"Nose",
"Neck",
"RShoulder",
"RElbow",
"RWrist", # No.4
"LShoulder",
"LElbow",
"LWrist", # No.7
"RHip",
"RKnee",
"RAnkle", # No.10
"LHip",
"LKnee",
"LAnkle", # No.13
"REye",
"LEye",
"REar",
"LEar",
"LToe",
"RToe",
]
# kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \
# kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2
kp2ds = kp2ds.copy()
kp2ds[[1,2,3,4,5,6,7,8,9,10,11,12,13,18,19], 2] = 0
if not draw_head:
kp2ds[[0,14,15,16,17], 2] = 0
kp2ds_body = kp2ds
# kp2ds_body = kp2ds_body[:18]
# kp2ds_lhand = kp2ds.copy()[91:112]
# kp2ds_rhand = kp2ds.copy()[112:133]
limbSeq = [
# [2, 3],
# [2, 6], # shoulders
# [3, 4],
# [4, 5], # left arm
# [6, 7],
# [7, 8], # right arm
# [2, 9],
# [9, 10],
# [10, 11], # right leg
# [2, 12],
# [12, 13],
# [13, 14], # left leg
# [2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18], # face (nose, eyes, ears)
# [14, 19],
# [11, 20], # foot
]
colors = [
# [255, 0, 0],
# [255, 85, 0],
# [255, 170, 0],
# [255, 255, 0],
# [170, 255, 0],
# [85, 255, 0],
# [0, 255, 0],
# [0, 255, 85],
# [0, 255, 170],
# [0, 255, 255],
# [0, 170, 255],
# [0, 85, 255],
# [0, 0, 255],
# [85, 0, 255],
[170, 0, 255],
[255, 0, 255],
[255, 0, 170],
[255, 0, 85],
# foot
# [200, 200, 0],
# [100, 100, 0],
]
H, W, C = img.shape
stickwidth = max(int(min(H, W) / stick_width_norm), 1)
for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):
keypoint1 = kp2ds_body[k1_index - 1]
keypoint2 = kp2ds_body[k2_index - 1]
if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
continue
Y = np.array([keypoint1[0], keypoint2[0]])
X = np.array([keypoint1[1], keypoint2[1]])
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])
for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):
if keypoint[-1] < threshold:
continue
x, y = keypoint[0], keypoint[1]
# cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)
if draw_hand:
img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)
img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)
kp2ds_body[:, 0] /= W
kp2ds_body[:, 1] /= H
if data_to_json is not None:
if idx == -1:
data_to_json.append(
{
"image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1),
"height": H,
"width": W,
"category_id": 1,
"keypoints_body": kp2ds_body.tolist(),
"keypoints_left_hand": kp2ds_lhand.tolist(),
"keypoints_right_hand": kp2ds_rhand.tolist(),
}
)
else:
data_to_json[idx] = {
"image_id": "frame_{:05d}.jpg".format(idx + 1),
"height": H,
"width": W,
"category_id": 1,
"keypoints_body": kp2ds_body.tolist(),
"keypoints_left_hand": kp2ds_lhand.tolist(),
"keypoints_right_hand": kp2ds_rhand.tolist(),
}
return img
def draw_nose(
img,
kp2ds,
threshold=0.6,
data_to_json=None,
idx=-1,
kp2ds_lhand=None,
kp2ds_rhand=None,
draw_hand=False,
stick_width_norm=200,
):
"""
Draw keypoints and connections representing hand pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
or None if no keypoints are present.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
new_kep_list = [
"Nose",
"Neck",
"RShoulder",
"RElbow",
"RWrist", # No.4
"LShoulder",
"LElbow",
"LWrist", # No.7
"RHip",
"RKnee",
"RAnkle", # No.10
"LHip",
"LKnee",
"LAnkle", # No.13
"REye",
"LEye",
"REar",
"LEar",
"LToe",
"RToe",
]
# kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \
# kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2
kp2ds = kp2ds.copy()
kp2ds[1:, 2] = 0
# kp2ds[0, 2] = 1
kp2ds_body = kp2ds
# kp2ds_body = kp2ds_body[:18]
# kp2ds_lhand = kp2ds.copy()[91:112]
# kp2ds_rhand = kp2ds.copy()[112:133]
limbSeq = [
# [2, 3],
# [2, 6], # shoulders
# [3, 4],
# [4, 5], # left arm
# [6, 7],
# [7, 8], # right arm
# [2, 9],
# [9, 10],
# [10, 11], # right leg
# [2, 12],
# [12, 13],
# [13, 14], # left leg
# [2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18], # face (nose, eyes, ears)
# [14, 19],
# [11, 20], # foot
]
colors = [
# [255, 0, 0],
# [255, 85, 0],
# [255, 170, 0],
# [255, 255, 0],
# [170, 255, 0],
# [85, 255, 0],
# [0, 255, 0],
# [0, 255, 85],
# [0, 255, 170],
# [0, 255, 255],
# [0, 170, 255],
# [0, 85, 255],
# [0, 0, 255],
# [85, 0, 255],
[170, 0, 255],
# [255, 0, 255],
# [255, 0, 170],
# [255, 0, 85],
# foot
# [200, 200, 0],
# [100, 100, 0],
]
H, W, C = img.shape
stickwidth = max(int(min(H, W) / stick_width_norm), 1)
# for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):
# keypoint1 = kp2ds_body[k1_index - 1]
# keypoint2 = kp2ds_body[k2_index - 1]
# if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
# continue
# Y = np.array([keypoint1[0], keypoint2[0]])
# X = np.array([keypoint1[1], keypoint2[1]])
# mX = np.mean(X)
# mY = np.mean(Y)
# length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
# angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
# polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
# cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])
for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):
if keypoint[-1] < threshold:
continue
x, y = keypoint[0], keypoint[1]
# cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)
if draw_hand:
img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)
img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)
kp2ds_body[:, 0] /= W
kp2ds_body[:, 1] /= H
if data_to_json is not None:
if idx == -1:
data_to_json.append(
{
"image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1),
"height": H,
"width": W,
"category_id": 1,
"keypoints_body": kp2ds_body.tolist(),
"keypoints_left_hand": kp2ds_lhand.tolist(),
"keypoints_right_hand": kp2ds_rhand.tolist(),
}
)
else:
data_to_json[idx] = {
"image_id": "frame_{:05d}.jpg".format(idx + 1),
"height": H,
"width": W,
"category_id": 1,
"keypoints_body": kp2ds_body.tolist(),
"keypoints_left_hand": kp2ds_lhand.tolist(),
"keypoints_right_hand": kp2ds_rhand.tolist(),
}
return img
def draw_aapose(
img,
kp2ds,
threshold=0.6,
data_to_json=None,
idx=-1,
kp2ds_lhand=None,
kp2ds_rhand=None,
draw_hand=False,
stick_width_norm=200,
draw_head=True
):
"""
Draw keypoints and connections representing hand pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
or None if no keypoints are present.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
new_kep_list = [
"Nose",
"Neck",
"RShoulder",
"RElbow",
"RWrist", # No.4
"LShoulder",
"LElbow",
"LWrist", # No.7
"RHip",
"RKnee",
"RAnkle", # No.10
"LHip",
"LKnee",
"LAnkle", # No.13
"REye",
"LEye",
"REar",
"LEar",
"LToe",
"RToe",
]
# kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \
# kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2
kp2ds = kp2ds.copy()
if not draw_head:
kp2ds[[0,14,15,16,17], 2] = 0
kp2ds_body = kp2ds
# kp2ds_lhand = kp2ds.copy()[91:112]
# kp2ds_rhand = kp2ds.copy()[112:133]
limbSeq = [
[2, 3],
[2, 6], # shoulders
[3, 4],
[4, 5], # left arm
[6, 7],
[7, 8], # right arm
[2, 9],
[9, 10],
[10, 11], # right leg
[2, 12],
[12, 13],
[13, 14], # left leg
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18], # face (nose, eyes, ears)
[14, 19],
[11, 20], # foot
]
colors = [
[255, 0, 0],
[255, 85, 0],
[255, 170, 0],
[255, 255, 0],
[170, 255, 0],
[85, 255, 0],
[0, 255, 0],
[0, 255, 85],
[0, 255, 170],
[0, 255, 255],
[0, 170, 255],
[0, 85, 255],
[0, 0, 255],
[85, 0, 255],
[170, 0, 255],
[255, 0, 255],
[255, 0, 170],
[255, 0, 85],
# foot
[200, 200, 0],
[100, 100, 0],
]
H, W, C = img.shape
stickwidth = max(int(min(H, W) / stick_width_norm), 1)
for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):
keypoint1 = kp2ds_body[k1_index - 1]
keypoint2 = kp2ds_body[k2_index - 1]
if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
continue
Y = np.array([keypoint1[0], keypoint2[0]])
X = np.array([keypoint1[1], keypoint2[1]])
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])
for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):
if keypoint[-1] < threshold:
continue
x, y = keypoint[0], keypoint[1]
# cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)
if draw_hand:
img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)
img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)
kp2ds_body[:, 0] /= W
kp2ds_body[:, 1] /= H
if data_to_json is not None:
if idx == -1:
data_to_json.append(
{
"image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1),
"height": H,
"width": W,
"category_id": 1,
"keypoints_body": kp2ds_body.tolist(),
"keypoints_left_hand": kp2ds_lhand.tolist(),
"keypoints_right_hand": kp2ds_rhand.tolist(),
}
)
else:
data_to_json[idx] = {
"image_id": "frame_{:05d}.jpg".format(idx + 1),
"height": H,
"width": W,
"category_id": 1,
"keypoints_body": kp2ds_body.tolist(),
"keypoints_left_hand": kp2ds_lhand.tolist(),
"keypoints_right_hand": kp2ds_rhand.tolist(),
}
return img
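# Hedged usage sketch (illustrative only): render a body skeleton with draw_aapose
# on a blank canvas. The canvas size is an assumption; kp2ds_body is expected as a
# [20, 3] array of (x, y, score) in pixel coordinates.
def _example_draw_aapose_on_canvas(kp2ds_body, kp2ds_lhand=None, kp2ds_rhand=None, height=512, width=512):
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    draw_hands = kp2ds_lhand is not None and kp2ds_rhand is not None
    return draw_aapose(canvas, kp2ds_body, threshold=0.5,
                       kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand,
                       draw_hand=draw_hands)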
def draw_aapose_new(
img,
kp2ds,
threshold=0.6,
data_to_json=None,
idx=-1,
kp2ds_lhand=None,
kp2ds_rhand=None,
draw_hand=False,
stickwidth_type='v2',
draw_head=True
):
"""
Draw keypoints and connections representing hand pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
or None if no keypoints are present.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
new_kep_list = [
"Nose",
"Neck",
"RShoulder",
"RElbow",
"RWrist", # No.4
"LShoulder",
"LElbow",
"LWrist", # No.7
"RHip",
"RKnee",
"RAnkle", # No.10
"LHip",
"LKnee",
"LAnkle", # No.13
"REye",
"LEye",
"REar",
"LEar",
"LToe",
"RToe",
]
# kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \
# kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2
kp2ds = kp2ds.copy()
if not draw_head:
kp2ds[[0,14,15,16,17], 2] = 0
kp2ds_body = kp2ds
# kp2ds_lhand = kp2ds.copy()[91:112]
# kp2ds_rhand = kp2ds.copy()[112:133]
limbSeq = [
[2, 3],
[2, 6], # shoulders
[3, 4],
[4, 5], # left arm
[6, 7],
[7, 8], # right arm
[2, 9],
[9, 10],
[10, 11], # right leg
[2, 12],
[12, 13],
[13, 14], # left leg
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18], # face (nose, eyes, ears)
[14, 19],
[11, 20], # foot
]
colors = [
[255, 0, 0],
[255, 85, 0],
[255, 170, 0],
[255, 255, 0],
[170, 255, 0],
[85, 255, 0],
[0, 255, 0],
[0, 255, 85],
[0, 255, 170],
[0, 255, 255],
[0, 170, 255],
[0, 85, 255],
[0, 0, 255],
[85, 0, 255],
[170, 0, 255],
[255, 0, 255],
[255, 0, 170],
[255, 0, 85],
# foot
[200, 200, 0],
[100, 100, 0],
]
H, W, C = img.shape
if stickwidth_type == 'v1':
stickwidth = max(int(min(H, W) / 200), 1)
elif stickwidth_type == 'v2':
stickwidth = max(int(min(H, W) / 200) - 1, 1)
else:
        raise ValueError(f"unknown stickwidth_type: {stickwidth_type}")
for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):
keypoint1 = kp2ds_body[k1_index - 1]
keypoint2 = kp2ds_body[k2_index - 1]
if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
continue
Y = np.array([keypoint1[0], keypoint2[0]])
X = np.array([keypoint1[1], keypoint2[1]])
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])
for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):
if keypoint[-1] < threshold:
continue
x, y = keypoint[0], keypoint[1]
# cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)
if draw_hand:
img = draw_handpose_new(img, kp2ds_lhand, stickwidth_type=stickwidth_type, hand_score_th=threshold)
img = draw_handpose_new(img, kp2ds_rhand, stickwidth_type=stickwidth_type, hand_score_th=threshold)
kp2ds_body[:, 0] /= W
kp2ds_body[:, 1] /= H
if data_to_json is not None:
if idx == -1:
data_to_json.append(
{
"image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1),
"height": H,
"width": W,
"category_id": 1,
"keypoints_body": kp2ds_body.tolist(),
"keypoints_left_hand": kp2ds_lhand.tolist(),
"keypoints_right_hand": kp2ds_rhand.tolist(),
}
)
else:
data_to_json[idx] = {
"image_id": "frame_{:05d}.jpg".format(idx + 1),
"height": H,
"width": W,
"category_id": 1,
"keypoints_body": kp2ds_body.tolist(),
"keypoints_left_hand": kp2ds_lhand.tolist(),
"keypoints_right_hand": kp2ds_rhand.tolist(),
}
return img
def draw_bbox(img, bbox, color=(255, 0, 0)):
img = load_image(img)
bbox = [int(bbox_tmp) for bbox_tmp in bbox]
cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
return img
def draw_kp2ds(img, kp2ds, threshold=0, color=(255, 0, 0), skeleton=None, reverse=False):
img = load_image(img, reverse)
if skeleton is not None:
if skeleton == "coco17":
skeleton_list = [
[6, 8],
[8, 10],
[5, 7],
[7, 9],
[11, 13],
[13, 15],
[12, 14],
[14, 16],
[5, 6],
[6, 12],
[12, 11],
[11, 5],
]
color_list = [
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 255, 255),
]
elif skeleton == "cocowholebody":
skeleton_list = [
[6, 8],
[8, 10],
[5, 7],
[7, 9],
[11, 13],
[13, 15],
[12, 14],
[14, 16],
[5, 6],
[6, 12],
[12, 11],
[11, 5],
[15, 17],
[15, 18],
[15, 19],
[16, 20],
[16, 21],
[16, 22],
[91, 92, 93, 94, 95],
[91, 96, 97, 98, 99],
[91, 100, 101, 102, 103],
[91, 104, 105, 106, 107],
[91, 108, 109, 110, 111],
[112, 113, 114, 115, 116],
[112, 117, 118, 119, 120],
[112, 121, 122, 123, 124],
[112, 125, 126, 127, 128],
[112, 129, 130, 131, 132],
]
color_list = [
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 255, 255),
]
        else:
            skeleton_list = []
            color_list = [color]
for _idx, _skeleton in enumerate(skeleton_list):
for i in range(len(_skeleton) - 1):
cv2.line(
img,
(int(kp2ds[_skeleton[i], 0]), int(kp2ds[_skeleton[i], 1])),
(int(kp2ds[_skeleton[i + 1], 0]), int(kp2ds[_skeleton[i + 1], 1])),
color_list[_idx % len(color_list)],
3,
)
for _idx, kp2d in enumerate(kp2ds):
if kp2d[2] > threshold:
cv2.circle(img, (int(kp2d[0]), int(kp2d[1])), 3, color, -1)
# cv2.putText(img,
# str(_idx),
# (int(kp2d[0, i, 0])*1,
# int(kp2d[0, i, 1])*1),
# cv2.FONT_HERSHEY_SIMPLEX,
# 0.75,
# color,
# 2
# )
return img
def draw_mask(img, mask, background=0, return_rgba=False):
img = load_image(img)
h, w, _ = img.shape
    if isinstance(background, int):
        background = np.ones((h, w, 3)).astype(np.uint8) * 255 * background
    background = cv2.resize(background, (w, h))
    img_rgba = np.concatenate([img, mask], -1)
    return alphaMerge(img_rgba, background, 0, 0, return_rgba=return_rgba)
def draw_pcd(pcd_list, save_path=None):
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
color_list = ["r", "g", "b", "y", "p"]
for _idx, _pcd in enumerate(pcd_list):
ax.scatter(_pcd[:, 0], _pcd[:, 1], _pcd[:, 2], c=color_list[_idx], marker="o")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
if save_path is not None:
plt.savefig(save_path)
else:
plt.savefig("tmp.png")
def load_image(img, reverse=False):
if type(img) == str:
img = cv2.imread(img)
if reverse:
img = img.astype(np.float32)
img = img[:, :, ::-1]
img = img.astype(np.uint8)
return img
def draw_skeleten(meta):
kps = []
for i, kp in enumerate(meta["keypoints_body"]):
if kp is None:
kps.append([0, 0, 0])
else:
kps.append([*kp, 1])
kps = np.array(kps)
kps[:, 0] *= meta["width"]
kps[:, 1] *= meta["height"]
pose_img = np.zeros([meta["height"], meta["width"], 3], dtype=np.uint8)
pose_img = draw_aapose(
pose_img,
kps,
draw_hand=True,
kp2ds_lhand=meta["keypoints_left_hand"],
kp2ds_rhand=meta["keypoints_right_hand"],
)
return pose_img
def draw_skeleten_with_pncc(pncc: np.ndarray, meta: Dict) -> np.ndarray:
"""
Args:
pncc: [H,W,3]
meta: required keys: keypoints_body: [N, 3] keypoints_left_hand, keypoints_right_hand
Return:
np.ndarray [H, W, 3]
"""
# preprocess keypoints
kps = []
for i, kp in enumerate(meta["keypoints_body"]):
if kp is None:
kps.append([0, 0, 0])
elif i in [14, 15, 16, 17]:
kps.append([0, 0, 0])
else:
kps.append([*kp])
kps = np.stack(kps)
kps[:, 0] *= pncc.shape[1]
kps[:, 1] *= pncc.shape[0]
# draw neck
canvas = np.zeros_like(pncc)
if kps[0][2] > 0.6 and kps[1][2] > 0.6:
canvas = draw_ellipse_by_2kp(canvas, kps[0], kps[1], [0, 0, 255])
# draw pncc
mask = (pncc > 0).max(axis=2)
canvas[mask] = pncc[mask]
pncc = canvas
    # draw the rest of the skeleton
kps[0] = 0
meta["keypoints_left_hand"][:, 0] *= meta["width"]
meta["keypoints_left_hand"][:, 1] *= meta["height"]
meta["keypoints_right_hand"][:, 0] *= meta["width"]
meta["keypoints_right_hand"][:, 1] *= meta["height"]
pose_img = draw_aapose(
pncc,
kps,
draw_hand=True,
kp2ds_lhand=meta["keypoints_left_hand"],
kp2ds_rhand=meta["keypoints_right_hand"],
)
return pose_img
FACE_CUSTOM_STYLE = {
"eyeball": {"indexs": [68, 69], "color": [255, 255, 255], "connect": False},
"left_eyebrow": {"indexs": [17, 18, 19, 20, 21], "color": [0, 255, 0]},
"right_eyebrow": {"indexs": [22, 23, 24, 25, 26], "color": [0, 0, 255]},
"left_eye": {"indexs": [36, 37, 38, 39, 40, 41], "color": [255, 255, 0], "close": True},
"right_eye": {"indexs": [42, 43, 44, 45, 46, 47], "color": [255, 0, 255], "close": True},
"mouth_outside": {"indexs": list(range(48, 60)), "color": [100, 255, 50], "close": True},
"mouth_inside": {"indexs": [60, 61, 62, 63, 64, 65, 66, 67], "color": [255, 100, 50], "close": True},
}
def draw_face_kp(img, kps, thickness=2, style=FACE_CUSTOM_STYLE):
"""
Args:
img: [H, W, 3]
kps: [70, 2]
"""
img = img.copy()
for key, item in style.items():
pts = np.array(kps[item["indexs"]]).astype(np.int32)
connect = item.get("connect", True)
color = item["color"]
close = item.get("close", False)
if connect:
cv2.polylines(img, [pts], close, color, thickness=thickness)
else:
for kp in pts:
kp = np.array(kp).astype(np.int32)
cv2.circle(img, kp, thickness * 2, color=color, thickness=-1)
return img
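# Illustrative sketch only: overlay the 70-point FACE_CUSTOM_STYLE layout on a copy
# of an image. The random keypoints below are placeholders, not real detections.
def _example_draw_face_kp(img):
    fake_face_kps = np.random.rand(70, 2) * np.array([img.shape[1], img.shape[0]])
    return draw_face_kp(img, fake_face_kps)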
def draw_traj(metas: List[AAPoseMeta], threshold=0.6):
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85], [100, 255, 50], [255, 100, 50],
# foot
[200, 200, 0],
[100, 100, 0]
]
limbSeq = [
[1, 2], [1, 5], # shoulders
[2, 3], [3, 4], # left arm
[5, 6], [6, 7], # right arm
[1, 8], [8, 9], [9, 10], # right leg
[1, 11], [11, 12], [12, 13], # left leg
# face (nose, eyes, ears)
[13, 18], [10, 19] # foot
]
face_seq = [[1, 0], [0, 14], [14, 16], [0, 15], [15, 17]]
kp_body = np.array([meta.kps_body for meta in metas])
kp_body_p = np.array([meta.kps_body_p for meta in metas])
face_seq = random.sample(face_seq, 2)
kp_lh = np.array([meta.kps_lhand for meta in metas])
kp_rh = np.array([meta.kps_rhand for meta in metas])
kp_lh_p = np.array([meta.kps_lhand_p for meta in metas])
kp_rh_p = np.array([meta.kps_rhand_p for meta in metas])
# kp_lh = np.concatenate([kp_lh, kp_lh_p], axis=-1)
# kp_rh = np.concatenate([kp_rh, kp_rh_p], axis=-1)
new_limbSeq = []
key_point_list = []
for _idx, ((k1_index, k2_index)) in enumerate(limbSeq):
vis = (kp_body_p[:, k1_index] > threshold) * (kp_body_p[:, k2_index] > threshold) * 1
if vis.sum() * 1.0 / vis.shape[0] > 0.4:
new_limbSeq.append([k1_index, k2_index])
for _idx, ((k1_index, k2_index)) in enumerate(limbSeq):
keypoint1 = kp_body[:, k1_index - 1]
keypoint2 = kp_body[:, k2_index - 1]
interleave = random.randint(4, 7)
randind = random.randint(0, interleave - 1)
# randind = random.rand(range(interleave), sampling_num)
Y = np.array([keypoint1[:, 0], keypoint2[:, 0]])
X = np.array([keypoint1[:, 1], keypoint2[:, 1]])
vis = (keypoint1[:, -1] > threshold) * (keypoint2[:, -1] > threshold) * 1
# for randidx in randind:
t = randind / interleave
x = (1-t)*Y[0, :] + t*Y[1, :]
y = (1-t)*X[0, :] + t*X[1, :]
# np.array([1])
x = x.astype(int)
y = y.astype(int)
new_array = np.array([x, y, vis]).T
key_point_list.append(new_array)
indx_lh = random.randint(0, kp_lh.shape[1] - 1)
lh = kp_lh[:, indx_lh, :]
lh_p = kp_lh_p[:, indx_lh:indx_lh+1]
lh = np.concatenate([lh, lh_p], axis=-1)
indx_rh = random.randint(0, kp_rh.shape[1] - 1)
    rh = kp_rh[:, indx_rh, :]  # use the same sampled finger index as rh_p
rh_p = kp_rh_p[:, indx_rh:indx_rh+1]
rh = np.concatenate([rh, rh_p], axis=-1)
    lh[:, -1] = (lh[:, -1] > threshold) * 1
    rh[:, -1] = (rh[:, -1] > threshold) * 1
key_point_list.append(lh.astype(int))
key_point_list.append(rh.astype(int))
key_points_list = np.stack(key_point_list)
num_points = len(key_points_list)
sample_colors = random.sample(colors, num_points)
stickwidth = max(int(min(metas[0].width, metas[0].height) / 150), 2)
image_list_ori = []
for i in range(key_points_list.shape[-2]):
        _image_vis = np.zeros((metas[0].height, metas[0].width, 3), dtype=np.uint8)
points = key_points_list[:, i, :]
for idx, point in enumerate(points):
x, y, vis = point
if vis == 1:
cv2.circle(_image_vis, (x, y), stickwidth, sample_colors[idx], thickness=-1)
image_list_ori.append(_image_vis)
return image_list_ori
    return [np.zeros([meta.height, meta.width, 3], dtype=np.uint8) for meta in metas]
if __name__ == "__main__":
meta = {
"image_id": "00472.jpg",
"height": 540,
"width": 414,
"category_id": 1,
"keypoints_body": [
[0.5084776947463768, 0.11350188078703703],
[0.504467655495169, 0.20419560185185184],
[0.3982016153381642, 0.198046875],
[0.3841664779589372, 0.34869068287037036],
[0.3901815368357488, 0.4670536747685185],
[0.610733695652174, 0.2103443287037037],
[0.6167487545289855, 0.3517650462962963],
[0.6448190292874396, 0.4762767650462963],
[0.4523371452294686, 0.47320240162037036],
[0.4503321256038647, 0.6776475694444445],
[0.47639738073671495, 0.8544234664351852],
[0.5766483620169082, 0.47320240162037036],
[0.5666232638888888, 0.6761103877314815],
[0.534542949879227, 0.863646556712963],
[0.4864224788647343, 0.09505570023148148],
[0.5285278910024155, 0.09351851851851851],
[0.46236224335748793, 0.10581597222222222],
[0.5586031853864735, 0.10274160879629629],
[0.4994551064311594, 0.9405056423611111],
[0.4152442821557971, 0.9312825520833333],
],
"keypoints_left_hand": [
[267.78515625, 263.830078125, 1.2840936183929443],
[265.294921875, 269.640625, 1.2546794414520264],
[263.634765625, 277.111328125, 1.2863062620162964],
[262.8046875, 285.412109375, 1.267038345336914],
[261.14453125, 292.8828125, 1.280144453048706],
[273.595703125, 281.26171875, 1.2592815160751343],
[271.10546875, 291.22265625, 1.3256099224090576],
[265.294921875, 294.54296875, 1.2368024587631226],
[261.14453125, 294.54296875, 0.9771889448165894],
[274.42578125, 282.091796875, 1.250044584274292],
[269.4453125, 291.22265625, 1.2571144104003906],
[264.46484375, 292.8828125, 1.177802324295044],
[260.314453125, 292.052734375, 0.9283463358879089],
[273.595703125, 282.091796875, 1.1834490299224854],
[269.4453125, 290.392578125, 1.188171625137329],
[265.294921875, 290.392578125, 1.192609429359436],
[261.974609375, 289.5625, 0.9366656541824341],
[271.935546875, 281.26171875, 1.0946396589279175],
[268.615234375, 287.072265625, 0.9906131029129028],
[265.294921875, 287.90234375, 1.0219476222991943],
[262.8046875, 287.072265625, 0.9240120053291321],
],
"keypoints_right_hand": [
[161.53515625, 258.849609375, 1.2069408893585205],
[168.17578125, 263.0, 1.1846840381622314],
[173.986328125, 269.640625, 1.1435924768447876],
[173.986328125, 277.94140625, 1.1802611351013184],
[173.986328125, 286.2421875, 1.2599592208862305],
[165.685546875, 275.451171875, 1.0633569955825806],
[167.345703125, 286.2421875, 1.1693341732025146],
[169.8359375, 291.22265625, 1.2698509693145752],
[170.666015625, 294.54296875, 1.0619274377822876],
[160.705078125, 276.28125, 1.0995020866394043],
[163.1953125, 287.90234375, 1.2735884189605713],
[166.515625, 291.22265625, 1.339503526687622],
[169.005859375, 294.54296875, 1.0835273265838623],
[157.384765625, 277.111328125, 1.0866981744766235],
[161.53515625, 287.072265625, 1.2468621730804443],
[164.025390625, 289.5625, 1.2817761898040771],
[166.515625, 292.052734375, 1.099466323852539],
[155.724609375, 277.111328125, 1.1065717935562134],
[159.044921875, 285.412109375, 1.1924479007720947],
[160.705078125, 287.072265625, 1.1304771900177002],
[162.365234375, 287.90234375, 1.0040509700775146],
],
}
demo_meta = AAPoseMeta(meta)
res = draw_traj([demo_meta]*5)
cv2.imwrite("traj.png", res[0][..., ::-1])
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import cv2
from typing import Union, List
import numpy as np
import torch
import onnxruntime
from pose2d_utils import (
read_img,
box_convert_simple,
bbox_from_detector,
crop,
keypoints_from_heatmaps,
load_pose_metas_from_kp2ds_seq
)
class SimpleOnnxInference(object):
def __init__(self, checkpoint, device='cuda', reverse_input=False, **kwargs):
if isinstance(device, str):
device = torch.device(device)
if device.type == 'cuda':
device = '{}:{}'.format(device.type, device.index)
providers = [("CUDAExecutionProvider", {"device_id": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else "0"}), "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
self.device = device
        if not os.path.exists(checkpoint):
            raise RuntimeError("{} does not exist!".format(checkpoint))
if os.path.isdir(checkpoint):
checkpoint = os.path.join(checkpoint, 'end2end.onnx')
self.session = onnxruntime.InferenceSession(checkpoint,
providers=providers
)
self.input_name = self.session.get_inputs()[0].name
self.output_name = self.session.get_outputs()[0].name
self.input_resolution = self.session.get_inputs()[0].shape[2:] if not reverse_input else self.session.get_inputs()[0].shape[2:][::-1]
self.input_resolution = np.array(self.input_resolution)
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def get_output_names(self):
output_names = []
for node in self.session.get_outputs():
output_names.append(node.name)
return output_names
def set_device(self, device):
if isinstance(device, str):
device = torch.device(device)
if device.type == 'cuda':
device = '{}:{}'.format(device.type, device.index)
providers = [("CUDAExecutionProvider", {"device_id": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else "0"}), "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
self.session.set_providers(providers)
self.device = device
class Yolo(SimpleOnnxInference):
def __init__(self, checkpoint, device='cuda', threshold_conf=0.05, threshold_multi_persons=0.1, input_resolution=(640, 640), threshold_iou=0.5, threshold_bbox_shape_ratio=0.4, cat_id=[1], select_type='max', strict=True, sorted_func=None, **kwargs):
super(Yolo, self).__init__(checkpoint, device=device, **kwargs)
model_inputs = self.session.get_inputs()
input_shape = model_inputs[0].shape
self.input_width = 640
self.input_height = 640
self.threshold_multi_persons = threshold_multi_persons
self.threshold_conf = threshold_conf
self.threshold_iou = threshold_iou
self.threshold_bbox_shape_ratio = threshold_bbox_shape_ratio
self.input_resolution = input_resolution
self.cat_id = cat_id
self.select_type = select_type
self.strict = strict
self.sorted_func = sorted_func
def preprocess(self, input_image):
"""
Preprocesses the input image before performing inference.
        Returns:
            image_data (np.ndarray): Preprocessed CHW float32 image ready for inference.
            shape (np.ndarray): The (height, width) of the original image.
        """
img = read_img(input_image)
# Get the height and width of the input image
img_height, img_width = img.shape[:2]
# Resize the image to match the input shape
img = cv2.resize(img, (self.input_resolution[1], self.input_resolution[0]))
# Normalize the image data by dividing it by 255.0
image_data = np.array(img) / 255.0
# Transpose the image to have the channel dimension as the first dimension
image_data = np.transpose(image_data, (2, 0, 1)) # Channel first
# Expand the dimensions of the image data to match the expected input shape
# image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
image_data = image_data.astype(np.float32)
# Return the preprocessed image data
return image_data, np.array([img_height, img_width])
def postprocess(self, output, shape_raw, cat_id=[1]):
"""
Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs.
Args:
input_image (numpy.ndarray): The input image.
output (numpy.ndarray): The output of the model.
Returns:
numpy.ndarray: The input image with detections drawn on it.
"""
# Transpose and squeeze the output to match the expected shape
outputs = np.squeeze(output)
if len(outputs.shape) == 1:
outputs = outputs[None]
if output.shape[-1] != 6 and output.shape[1] == 84:
outputs = np.transpose(outputs)
# Get the number of rows in the outputs array
rows = outputs.shape[0]
# Calculate the scaling factors for the bounding box coordinates
x_factor = shape_raw[1] / self.input_width
y_factor = shape_raw[0] / self.input_height
# Lists to store the bounding boxes, scores, and class IDs of the detections
boxes = []
scores = []
class_ids = []
if outputs.shape[-1] == 6:
max_scores = outputs[:, 4]
classid = outputs[:, -1]
threshold_conf_masks = max_scores >= self.threshold_conf
classid_masks = classid[threshold_conf_masks] != 3.14159
max_scores = max_scores[threshold_conf_masks][classid_masks]
classid = classid[threshold_conf_masks][classid_masks]
boxes = outputs[:, :4][threshold_conf_masks][classid_masks]
boxes[:, [0, 2]] *= x_factor
boxes[:, [1, 3]] *= y_factor
boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
boxes = boxes.astype(np.int32)
else:
classes_scores = outputs[:, 4:]
max_scores = np.amax(classes_scores, -1)
threshold_conf_masks = max_scores >= self.threshold_conf
classid = np.argmax(classes_scores[threshold_conf_masks], -1)
classid_masks = classid!=3.14159
classes_scores = classes_scores[threshold_conf_masks][classid_masks]
max_scores = max_scores[threshold_conf_masks][classid_masks]
classid = classid[classid_masks]
xywh = outputs[:, :4][threshold_conf_masks][classid_masks]
x = xywh[:, 0:1]
y = xywh[:, 1:2]
w = xywh[:, 2:3]
h = xywh[:, 3:4]
left = ((x - w / 2) * x_factor)
top = ((y - h / 2) * y_factor)
width = (w * x_factor)
height = (h * y_factor)
boxes = np.concatenate([left, top, width, height], axis=-1).astype(np.int32)
boxes = boxes.tolist()
scores = max_scores.tolist()
class_ids = classid.tolist()
# Apply non-maximum suppression to filter out overlapping bounding boxes
indices = cv2.dnn.NMSBoxes(boxes, scores, self.threshold_conf, self.threshold_iou)
# Iterate over the selected indices after non-maximum suppression
results = []
for i in indices:
# Get the box, score, and class ID corresponding to the index
box = box_convert_simple(boxes[i], 'xywh2xyxy')
score = scores[i]
class_id = class_ids[i]
results.append(box + [score] + [class_id])
        # Return the kept detections, one row per box: [x0, y0, x1, y1, score, class_id]
return np.array(results)
def process_results(self, results, shape_raw, cat_id=[1], single_person=True):
if isinstance(results, tuple):
det_results = results[0]
else:
det_results = results
person_results = []
person_count = 0
if len(results):
max_idx = -1
max_bbox_size = shape_raw[0] * shape_raw[1] * -10
max_bbox_shape = -1
bboxes = []
idx_list = []
for i in range(results.shape[0]):
bbox = results[i]
if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf):
idx_list.append(i)
bbox_shape = max((bbox[2] - bbox[0]), ((bbox[3] - bbox[1])))
if bbox_shape > max_bbox_shape:
max_bbox_shape = bbox_shape
results = results[idx_list]
for i in range(results.shape[0]):
bbox = results[i]
bboxes.append(bbox)
if self.select_type == 'max':
bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1]))
elif self.select_type == 'center':
bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1
bbox_shape = max((bbox[2] - bbox[0]), ((bbox[3] - bbox[1])))
if bbox_size > max_bbox_size:
if (self.strict or max_idx != -1) and bbox_shape < max_bbox_shape * self.threshold_bbox_shape_ratio:
continue
max_bbox_size = bbox_size
max_bbox_shape = bbox_shape
max_idx = i
if self.sorted_func is not None and len(bboxes) > 0:
max_idx = self.sorted_func(bboxes, shape_raw)
bbox = bboxes[max_idx]
if self.select_type == 'max':
max_bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1]))
elif self.select_type == 'center':
max_bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1
if max_idx != -1:
person_count = 1
if max_idx != -1:
person = {}
person['bbox'] = results[max_idx, :5]
person['track_id'] = int(0)
person_results.append(person)
for i in range(results.shape[0]):
bbox = results[i]
if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf):
if self.select_type == 'max':
bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1]))
elif self.select_type == 'center':
bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1
if i != max_idx and bbox_size > max_bbox_size * self.threshold_multi_persons and bbox_size < max_bbox_size:
person_count += 1
if not single_person:
person = {}
person['bbox'] = results[i, :5]
person['track_id'] = int(person_count - 1)
person_results.append(person)
return person_results
else:
return None
def postprocess_threading(self, outputs, shape_raw, person_results, i, single_person=True, **kwargs):
result = self.postprocess(outputs[i], shape_raw[i], cat_id=self.cat_id)
result = self.process_results(result, shape_raw[i], cat_id=self.cat_id, single_person=single_person)
if result is not None and len(result) != 0:
person_results[i] = result
def forward(self, img, shape_raw, **kwargs):
"""
Performs inference using an ONNX model and returns the output image with drawn detections.
Returns:
output_img: The output image with drawn detections.
"""
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
shape_raw = shape_raw.cpu().numpy()
outputs = self.session.run(None, {self.session.get_inputs()[0].name: img})[0]
person_results = [[{'bbox': np.array([0., 0., 1.*shape_raw[i][1], 1.*shape_raw[i][0], -1]), 'track_id': -1}] for i in range(len(outputs))]
for i in range(len(outputs)):
self.postprocess_threading(outputs, shape_raw, person_results, i, **kwargs)
return person_results
class ViTPose(SimpleOnnxInference):
def __init__(self, checkpoint, device='cuda', **kwargs):
super(ViTPose, self).__init__(checkpoint, device=device)
def forward(self, img, center, scale, **kwargs):
heatmaps = self.session.run([], {self.session.get_inputs()[0].name: img})[0]
points, prob = keypoints_from_heatmaps(heatmaps=heatmaps,
center=center,
scale=scale*200,
unbiased=True,
use_udp=False)
return np.concatenate([points, prob], axis=2)
@staticmethod
def preprocess(img, bbox=None, input_resolution=(256, 192), rescale=1.25, mask=None, **kwargs):
if bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10:
bbox = np.array([0, 0, img.shape[1], img.shape[0]])
bbox_xywh = bbox
if mask is not None:
img = np.where(mask>128, img, mask)
if isinstance(input_resolution, int):
center, scale = bbox_from_detector(bbox_xywh, (input_resolution, input_resolution), rescale=rescale)
img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution, input_resolution))
else:
center, scale = bbox_from_detector(bbox_xywh, input_resolution, rescale=rescale)
img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution[0], input_resolution[1]))
IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406])
IMG_NORM_STD = np.array([0.229, 0.224, 0.225])
img_norm = (img / 255. - IMG_NORM_MEAN) / IMG_NORM_STD
img_norm = img_norm.transpose(2, 0, 1).astype(np.float32)
return img_norm, np.array(center), np.array(scale)
class Pose2d:
def __init__(self, checkpoint, detector_checkpoint=None, device='cuda', **kwargs):
if detector_checkpoint is not None:
self.detector = Yolo(detector_checkpoint, device)
else:
self.detector = None
self.model = ViTPose(checkpoint, device)
self.device = device
def load_images(self, inputs):
"""
Load images from various input types.
Args:
inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path,
single image array, or list of image arrays
Returns:
List[np.ndarray]: List of RGB image arrays
Raises:
ValueError: If file format is unsupported or image cannot be read
"""
if isinstance(inputs, str):
if inputs.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
cap = cv2.VideoCapture(inputs)
frames = []
while True:
ret, frame = cap.read()
if not ret:
break
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()
images = frames
            elif inputs.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                img = cv2.imread(inputs)
                if img is None:
                    raise ValueError(f"Cannot read image: {inputs}")
                images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB)]
else:
raise ValueError(f"Unsupported file format: {inputs}")
        elif isinstance(inputs, np.ndarray):
            # a single HxWx3 image, or a stacked NxHxWx3 array of frames
            frames = [inputs] if inputs.ndim == 3 else list(inputs)
            images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in frames]
        elif isinstance(inputs, list):
            images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs]
        else:
            raise ValueError(f"Unsupported input type: {type(inputs)}")
        return images
def __call__(
self,
inputs: Union[str, np.ndarray, List[np.ndarray]],
return_image: bool = False,
**kwargs
):
"""
Process input and estimate 2D keypoints.
Args:
inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path,
single image array, or list of image arrays
**kwargs: Additional arguments for processing
Returns:
np.ndarray: Array of detected 2D keypoints for all input images
"""
images = self.load_images(inputs)
H, W = images[0].shape[:2]
if self.detector is not None:
bboxes = []
for _image in images:
img, shape = self.detector.preprocess(_image)
bboxes.append(self.detector(img[None], shape[None])[0][0]["bbox"])
else:
bboxes = [None] * len(images)
kp2ds = []
for _image, _bbox in zip(images, bboxes):
img, center, scale = self.model.preprocess(_image, _bbox)
kp2ds.append(self.model(img[None], center[None], scale[None]))
kp2ds = np.concatenate(kp2ds, 0)
metas = load_pose_metas_from_kp2ds_seq(kp2ds, width=W, height=H)
return metas
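# Hedged usage sketch: run the full 2D pose pipeline on a video file. The checkpoint
# paths below are placeholders (assumptions), not files shipped with this module.
def _example_run_pose2d(video_path="input.mp4", pose_ckpt="vitpose.onnx", det_ckpt="yolo.onnx"):
    pose2d = Pose2d(pose_ckpt, detector_checkpoint=det_ckpt, device="cuda")
    metas = pose2d(video_path)  # one pose meta per frame
    return metas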
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import warnings
import os.path as osp
import cv2
import numpy as np
from typing import List
from PIL import Image
def box_convert_simple(box, convert_type='xyxy2xywh'):
if convert_type == 'xyxy2xywh':
return [box[0], box[1], box[2] - box[0], box[3] - box[1]]
elif convert_type == 'xywh2xyxy':
return [box[0], box[1], box[2] + box[0], box[3] + box[1]]
elif convert_type == 'xyxy2ctwh':
return [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2, box[2] - box[0], box[3] - box[1]]
elif convert_type == 'ctwh2xyxy':
return [box[0] - box[2] // 2, box[1] - box[3] // 2, box[0] + (box[2] - box[2] // 2), box[1] + (box[3] - box[3] // 2)]
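# Quick illustrative check of the box conversions (the values are arbitrary).
def _example_box_convert():
    xyxy = [10, 20, 110, 220]
    xywh = box_convert_simple(xyxy, 'xyxy2xywh')   # [10, 20, 100, 200]
    back = box_convert_simple(xywh, 'xywh2xyxy')   # [10, 20, 110, 220]
    ctwh = box_convert_simple(xyxy, 'xyxy2ctwh')   # [60.0, 120.0, 100, 200]
    return xywh, back, ctwh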
def read_img(image, convert='RGB', check_exist=False):
if isinstance(image, str):
if check_exist and not osp.exists(image):
return None
try:
img = Image.open(image)
if convert:
img = img.convert(convert)
        except Exception as e:
            raise IOError('File error: {}'.format(image)) from e
return np.asarray(img)
else:
        if isinstance(image, np.ndarray):
            # assume BGR input; reverse channels when a conversion is requested
            return image[..., ::-1] if convert else image
        else:
            # assume a PIL.Image instance
            img = image.convert(convert) if convert else image
            return np.asarray(img)
class AAPoseMeta:
def __init__(self, meta=None, kp2ds=None):
self.image_id = ""
self.height = 0
self.width = 0
self.kps_body: np.ndarray = None
self.kps_lhand: np.ndarray = None
self.kps_rhand: np.ndarray = None
self.kps_face: np.ndarray = None
self.kps_body_p: np.ndarray = None
self.kps_lhand_p: np.ndarray = None
self.kps_rhand_p: np.ndarray = None
self.kps_face_p: np.ndarray = None
if meta is not None:
self.load_from_meta(meta)
elif kp2ds is not None:
self.load_from_kp2ds(kp2ds)
def is_valid(self, kp, p, threshold):
x, y = kp
if x < 0 or y < 0 or x > self.width or y > self.height or p < threshold:
return False
else:
return True
def get_bbox(self, kp, kp_p, threshold=0.5):
kps = kp[kp_p > threshold]
if kps.size == 0:
return 0, 0, 0, 0
x0, y0 = kps.min(axis=0)
x1, y1 = kps.max(axis=0)
return x0, y0, x1, y1
def crop(self, x0, y0, x1, y1):
all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]
for kps in all_kps:
if kps is not None:
kps[:, 0] -= x0
kps[:, 1] -= y0
self.width = x1 - x0
self.height = y1 - y0
return self
def resize(self, width, height):
scale_x = width / self.width
scale_y = height / self.height
all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]
for kps in all_kps:
if kps is not None:
kps[:, 0] *= scale_x
kps[:, 1] *= scale_y
self.width = width
self.height = height
return self
def get_kps_body_with_p(self, normalize=False):
kps_body = self.kps_body.copy()
if normalize:
kps_body = kps_body / np.array([self.width, self.height])
        return np.concatenate([kps_body, self.kps_body_p[:, None]], axis=1)
@staticmethod
def from_kps_face(kps_face: np.ndarray, height: int, width: int):
pose_meta = AAPoseMeta()
pose_meta.kps_face = kps_face[:, :2]
if kps_face.shape[1] == 3:
pose_meta.kps_face_p = kps_face[:, 2]
else:
pose_meta.kps_face_p = kps_face[:, 0] * 0 + 1
pose_meta.height = height
pose_meta.width = width
return pose_meta
@staticmethod
def from_kps_body(kps_body: np.ndarray, height: int, width: int):
pose_meta = AAPoseMeta()
pose_meta.kps_body = kps_body[:, :2]
pose_meta.kps_body_p = kps_body[:, 2]
pose_meta.height = height
pose_meta.width = width
return pose_meta
@staticmethod
def from_humanapi_meta(meta):
pose_meta = AAPoseMeta()
width, height = meta["width"], meta["height"]
pose_meta.width = width
pose_meta.height = height
pose_meta.kps_body = meta["keypoints_body"][:, :2] * (width, height)
pose_meta.kps_body_p = meta["keypoints_body"][:, 2]
pose_meta.kps_lhand = meta["keypoints_left_hand"][:, :2] * (width, height)
pose_meta.kps_lhand_p = meta["keypoints_left_hand"][:, 2]
pose_meta.kps_rhand = meta["keypoints_right_hand"][:, :2] * (width, height)
pose_meta.kps_rhand_p = meta["keypoints_right_hand"][:, 2]
if 'keypoints_face' in meta:
pose_meta.kps_face = meta["keypoints_face"][:, :2] * (width, height)
pose_meta.kps_face_p = meta["keypoints_face"][:, 2]
return pose_meta
def load_from_meta(self, meta, norm_body=True, norm_hand=False):
self.image_id = meta.get("image_id", "00000.png")
self.height = meta["height"]
self.width = meta["width"]
kps_body_p = []
kps_body = []
for kp in meta["keypoints_body"]:
if kp is None:
kps_body.append([0, 0])
kps_body_p.append(0)
else:
kps_body.append(kp)
kps_body_p.append(1)
self.kps_body = np.array(kps_body)
self.kps_body[:, 0] *= self.width
self.kps_body[:, 1] *= self.height
self.kps_body_p = np.array(kps_body_p)
self.kps_lhand = np.array(meta["keypoints_left_hand"])[:, :2]
self.kps_lhand_p = np.array(meta["keypoints_left_hand"])[:, 2]
self.kps_rhand = np.array(meta["keypoints_right_hand"])[:, :2]
self.kps_rhand_p = np.array(meta["keypoints_right_hand"])[:, 2]
@staticmethod
    def load_from_kp2ds(kp2ds: np.ndarray, width: int, height: int):
        """Build an AAPoseMeta from a 133x3 whole-body keypoint array.
        Args:
            kp2ds (np.ndarray): [133, 3] whole-body keypoints as (x, y, score).
            width (int): Image width in pixels.
            height (int): Image height in pixels.
        Returns:
            AAPoseMeta: Meta with body, hand and face keypoints split out.
        """
pose_meta = AAPoseMeta()
pose_meta.width = width
pose_meta.height = height
kps_body = (kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2
kps_lhand = kp2ds[91:112]
kps_rhand = kp2ds[112:133]
kps_face = np.concatenate([kp2ds[23:23+68], kp2ds[1:3]], axis=0)
pose_meta.kps_body = kps_body[:, :2]
pose_meta.kps_body_p = kps_body[:, 2]
pose_meta.kps_lhand = kps_lhand[:, :2]
pose_meta.kps_lhand_p = kps_lhand[:, 2]
pose_meta.kps_rhand = kps_rhand[:, :2]
pose_meta.kps_rhand_p = kps_rhand[:, 2]
pose_meta.kps_face = kps_face[:, :2]
pose_meta.kps_face_p = kps_face[:, 2]
return pose_meta
@staticmethod
def from_dwpose(dwpose_det_res, height, width):
pose_meta = AAPoseMeta()
pose_meta.kps_body = dwpose_det_res["bodies"]["candidate"]
pose_meta.kps_body_p = dwpose_det_res["bodies"]["score"]
pose_meta.kps_body[:, 0] *= width
pose_meta.kps_body[:, 1] *= height
pose_meta.kps_lhand, pose_meta.kps_rhand = dwpose_det_res["hands"]
pose_meta.kps_lhand[:, 0] *= width
pose_meta.kps_lhand[:, 1] *= height
pose_meta.kps_rhand[:, 0] *= width
pose_meta.kps_rhand[:, 1] *= height
pose_meta.kps_lhand_p, pose_meta.kps_rhand_p = dwpose_det_res["hands_score"]
pose_meta.kps_face = dwpose_det_res["faces"][0]
pose_meta.kps_face[:, 0] *= width
pose_meta.kps_face[:, 1] *= height
pose_meta.kps_face_p = dwpose_det_res["faces_score"][0]
return pose_meta
def save_json(self):
pass
def draw_aapose(self, img, threshold=0.5, stick_width_norm=200, draw_hand=True, draw_head=True):
from .human_visualization import draw_aapose_by_meta
return draw_aapose_by_meta(img, self, threshold, stick_width_norm, draw_hand, draw_head)
def translate(self, x0, y0):
all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]
for kps in all_kps:
if kps is not None:
kps[:, 0] -= x0
kps[:, 1] -= y0
def scale(self, sx, sy):
all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]
for kps in all_kps:
if kps is not None:
kps[:, 0] *= sx
kps[:, 1] *= sy
def padding_resize2(self, height=512, width=512):
"""kps will be changed inplace
"""
all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]
ori_height, ori_width = self.height, self.width
if (ori_height / ori_width) > (height / width):
new_width = int(height / ori_height * ori_width)
padding = int((width - new_width) / 2)
padding_width = padding
padding_height = 0
scale = height / ori_height
for kps in all_kps:
if kps is not None:
kps[:, 0] = kps[:, 0] * scale + padding
kps[:, 1] = kps[:, 1] * scale
else:
new_height = int(width / ori_width * ori_height)
padding = int((height - new_height) / 2)
padding_width = 0
padding_height = padding
scale = width / ori_width
for kps in all_kps:
if kps is not None:
kps[:, 1] = kps[:, 1] * scale + padding
kps[:, 0] = kps[:, 0] * scale
self.width = width
self.height = height
return self
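# Illustrative sketch: build an AAPoseMeta from a 133x3 whole-body keypoint array
# (random placeholder values) and letterbox it to 512x512 with padding_resize2.
def _example_aapose_meta():
    kp2ds = np.random.rand(133, 3) * np.array([640, 480, 1])
    meta = AAPoseMeta.load_from_kp2ds(kp2ds, width=640, height=480)
    meta.padding_resize2(height=512, width=512)
    return meta.kps_body.shape  # (20, 2), now expressed in the 512x512 frame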
def transform_preds(coords, center, scale, output_size, use_udp=False):
"""Get final keypoint predictions from heatmaps and apply scaling and
translation to map them back to the image.
Note:
num_keypoints: K
Args:
        coords (np.ndarray[K, ndims]):
            * If ndims=2, coords are predicted keypoint locations.
            * If ndims=4, coords are composed of (x, y, scores, tags).
            * If ndims=5, coords are composed of (x, y, scores, tags,
              flipped_tags).
center (np.ndarray[2, ]): Center of the bounding box (x, y).
scale (np.ndarray[2, ]): Scale of the bounding box
wrt [width, height].
output_size (np.ndarray[2, ] | list(2,)): Size of the
destination heatmaps.
use_udp (bool): Use unbiased data processing
Returns:
np.ndarray: Predicted coordinates in the images.
"""
assert coords.shape[1] in (2, 4, 5)
assert len(center) == 2
assert len(scale) == 2
assert len(output_size) == 2
# Recover the scale which is normalized by a factor of 200.
# scale = scale * 200.0
if use_udp:
scale_x = scale[0] / (output_size[0] - 1.0)
scale_y = scale[1] / (output_size[1] - 1.0)
else:
scale_x = scale[0] / output_size[0]
scale_y = scale[1] / output_size[1]
target_coords = np.ones_like(coords)
target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5
target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5
return target_coords
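# Minimal numeric sketch of transform_preds: one heatmap-space coordinate is mapped
# back to the original image frame. The numbers are chosen only for illustration;
# upstream callers typically pass the detector scale already multiplied by 200.
def _example_transform_preds():
    coords = np.array([[24.0, 32.0]])     # one keypoint in heatmap coordinates
    center = np.array([100.0, 200.0])     # bbox center in the image
    scale = np.array([192.0, 256.0])      # bbox size in pixels
    out = transform_preds(coords, center, scale, output_size=[48, 64])
    # x = 24 * (192 / 48) + 100 - 96 = 100 ; y = 32 * (256 / 64) + 200 - 128 = 200
    return out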
def _calc_distances(preds, targets, mask, normalize):
"""Calculate the normalized distances between preds and target.
Note:
batch_size: N
num_keypoints: K
dimension of keypoints: D (normally, D=2 or D=3)
Args:
preds (np.ndarray[N, K, D]): Predicted keypoint location.
targets (np.ndarray[N, K, D]): Groundtruth keypoint location.
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.
normalize (np.ndarray[N, D]): Typical value is heatmap_size
Returns:
np.ndarray[K, N]: The normalized distances. \
If target keypoints are missing, the distance is -1.
"""
N, K, _ = preds.shape
# set mask=0 when normalize==0
_mask = mask.copy()
_mask[np.where((normalize == 0).sum(1))[0], :] = False
distances = np.full((N, K), -1, dtype=np.float32)
# handle invalid values
normalize[np.where(normalize <= 0)] = 1e6
distances[_mask] = np.linalg.norm(
((preds - targets) / normalize[:, None, :])[_mask], axis=-1)
return distances.T
def _distance_acc(distances, thr=0.5):
"""Return the percentage below the distance threshold, while ignoring
distances values with -1.
Note:
batch_size: N
Args:
distances (np.ndarray[N, ]): The normalized distances.
thr (float): Threshold of the distances.
Returns:
float: Percentage of distances below the threshold. \
If all target keypoints are missing, return -1.
"""
distance_valid = distances != -1
num_distance_valid = distance_valid.sum()
if num_distance_valid > 0:
return (distances[distance_valid] < thr).sum() / num_distance_valid
return -1
def _get_max_preds(heatmaps):
"""Get keypoint predictions from score maps.
Note:
batch_size: N
num_keypoints: K
heatmap height: H
heatmap width: W
Args:
heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps.
Returns:
tuple: A tuple containing aggregated results.
- preds (np.ndarray[N, K, 2]): Predicted keypoint location.
- maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
"""
assert isinstance(heatmaps,
np.ndarray), ('heatmaps should be numpy.ndarray')
assert heatmaps.ndim == 4, 'batch_images should be 4-ndim'
N, K, _, W = heatmaps.shape
heatmaps_reshaped = heatmaps.reshape((N, K, -1))
idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1))
maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1))
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
preds[:, :, 0] = preds[:, :, 0] % W
preds[:, :, 1] = preds[:, :, 1] // W
preds = np.where(np.tile(maxvals, (1, 1, 2)) > 0.0, preds, -1)
return preds, maxvals
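# Illustrative sketch: decode the argmax keypoint from a tiny synthetic heatmap
# with _get_max_preds. The heatmap content is made up for the example.
def _example_get_max_preds():
    heatmaps = np.zeros((1, 1, 4, 5), dtype=np.float32)  # N=1, K=1, H=4, W=5
    heatmaps[0, 0, 2, 3] = 1.0                           # peak at (x=3, y=2)
    preds, maxvals = _get_max_preds(heatmaps)
    return preds, maxvals                                # preds[0, 0] == [3., 2.]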
def _get_max_preds_3d(heatmaps):
"""Get keypoint predictions from 3D score maps.
Note:
batch size: N
num keypoints: K
heatmap depth size: D
heatmap height: H
heatmap width: W
Args:
heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps.
Returns:
tuple: A tuple containing aggregated results.
- preds (np.ndarray[N, K, 3]): Predicted keypoint location.
- maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
"""
assert isinstance(heatmaps, np.ndarray), \
('heatmaps should be numpy.ndarray')
assert heatmaps.ndim == 5, 'heatmaps should be 5-ndim'
N, K, D, H, W = heatmaps.shape
heatmaps_reshaped = heatmaps.reshape((N, K, -1))
idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1))
maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1))
preds = np.zeros((N, K, 3), dtype=np.float32)
_idx = idx[..., 0]
preds[..., 2] = _idx // (H * W)
preds[..., 1] = (_idx // W) % H
preds[..., 0] = _idx % W
preds = np.where(maxvals > 0.0, preds, -1)
return preds, maxvals
def pose_pck_accuracy(output, target, mask, thr=0.05, normalize=None):
"""Calculate the pose accuracy of PCK for each individual keypoint and the
averaged accuracy across all keypoints from heatmaps.
Note:
PCK metric measures accuracy of the localization of the body joints.
The distances between predicted positions and the ground-truth ones
are typically normalized by the bounding box size.
The threshold (thr) of the normalized distance is commonly set
as 0.05, 0.1 or 0.2 etc.
- batch_size: N
- num_keypoints: K
- heatmap height: H
- heatmap width: W
Args:
output (np.ndarray[N, K, H, W]): Model output heatmaps.
target (np.ndarray[N, K, H, W]): Groundtruth heatmaps.
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.
thr (float): Threshold of PCK calculation. Default 0.05.
normalize (np.ndarray[N, 2]): Normalization factor for H&W.
Returns:
tuple: A tuple containing keypoint accuracy.
- np.ndarray[K]: Accuracy of each keypoint.
- float: Averaged accuracy across all keypoints.
- int: Number of valid keypoints.
"""
N, K, H, W = output.shape
if K == 0:
return None, 0, 0
if normalize is None:
normalize = np.tile(np.array([[H, W]]), (N, 1))
pred, _ = _get_max_preds(output)
gt, _ = _get_max_preds(target)
return keypoint_pck_accuracy(pred, gt, mask, thr, normalize)
def keypoint_pck_accuracy(pred, gt, mask, thr, normalize):
"""Calculate the pose accuracy of PCK for each individual keypoint and the
averaged accuracy across all keypoints for coordinates.
Note:
PCK metric measures accuracy of the localization of the body joints.
The distances between predicted positions and the ground-truth ones
are typically normalized by the bounding box size.
The threshold (thr) of the normalized distance is commonly set
as 0.05, 0.1 or 0.2 etc.
- batch_size: N
- num_keypoints: K
Args:
pred (np.ndarray[N, K, 2]): Predicted keypoint location.
gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.
thr (float): Threshold of PCK calculation.
normalize (np.ndarray[N, 2]): Normalization factor for H&W.
Returns:
tuple: A tuple containing keypoint accuracy.
- acc (np.ndarray[K]): Accuracy of each keypoint.
- avg_acc (float): Averaged accuracy across all keypoints.
- cnt (int): Number of valid keypoints.
"""
distances = _calc_distances(pred, gt, mask, normalize)
acc = np.array([_distance_acc(d, thr) for d in distances])
valid_acc = acc[acc >= 0]
cnt = len(valid_acc)
avg_acc = valid_acc.mean() if cnt > 0 else 0
return acc, avg_acc, cnt
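# Small illustrative PCK computation with synthetic predictions and targets; the
# normalization factor of 100 is an arbitrary stand-in for a bbox/heatmap size.
def _example_pck():
    gt = np.zeros((1, 3, 2))                                      # 1 sample, 3 keypoints
    pred = gt + np.array([[[1.0, 0.0], [10.0, 0.0], [0.0, 0.0]]])
    mask = np.ones((1, 3), dtype=bool)
    normalize = np.full((1, 2), 100.0)
    acc, avg_acc, cnt = keypoint_pck_accuracy(pred, gt, mask, thr=0.05, normalize=normalize)
    # normalized distances are 0.01, 0.1 and 0.0 -> two of three fall under thr=0.05
    return acc, avg_acc, cnt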
def keypoint_auc(pred, gt, mask, normalize, num_step=20):
"""Calculate the pose accuracy of PCK for each individual keypoint and the
averaged accuracy across all keypoints for coordinates.
Note:
- batch_size: N
- num_keypoints: K
Args:
pred (np.ndarray[N, K, 2]): Predicted keypoint location.
gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.
normalize (float): Normalization factor.
Returns:
float: Area under curve.
"""
nor = np.tile(np.array([[normalize, normalize]]), (pred.shape[0], 1))
x = [1.0 * i / num_step for i in range(num_step)]
y = []
for thr in x:
_, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor)
y.append(avg_acc)
auc = 0
for i in range(num_step):
auc += 1.0 / num_step * y[i]
return auc
def keypoint_nme(pred, gt, mask, normalize_factor):
"""Calculate the normalized mean error (NME).
Note:
- batch_size: N
- num_keypoints: K
Args:
pred (np.ndarray[N, K, 2]): Predicted keypoint location.
gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.
normalize_factor (np.ndarray[N, 2]): Normalization factor.
Returns:
float: normalized mean error
"""
distances = _calc_distances(pred, gt, mask, normalize_factor)
distance_valid = distances[distances != -1]
return distance_valid.sum() / max(1, len(distance_valid))
def keypoint_epe(pred, gt, mask):
"""Calculate the end-point error.
Note:
- batch_size: N
- num_keypoints: K
Args:
pred (np.ndarray[N, K, 2]): Predicted keypoint location.
gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.
Returns:
float: Average end-point error.
"""
distances = _calc_distances(
pred, gt, mask,
np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32))
distance_valid = distances[distances != -1]
return distance_valid.sum() / max(1, len(distance_valid))
def _taylor(heatmap, coord):
"""Distribution aware coordinate decoding method.
Note:
- heatmap height: H
- heatmap width: W
Args:
heatmap (np.ndarray[H, W]): Heatmap of a particular joint type.
coord (np.ndarray[2,]): Coordinates of the predicted keypoints.
Returns:
np.ndarray[2,]: Updated coordinates.
"""
H, W = heatmap.shape[:2]
px, py = int(coord[0]), int(coord[1])
if 1 < px < W - 2 and 1 < py < H - 2:
dx = 0.5 * (heatmap[py][px + 1] - heatmap[py][px - 1])
dy = 0.5 * (heatmap[py + 1][px] - heatmap[py - 1][px])
dxx = 0.25 * (
heatmap[py][px + 2] - 2 * heatmap[py][px] + heatmap[py][px - 2])
dxy = 0.25 * (
heatmap[py + 1][px + 1] - heatmap[py - 1][px + 1] -
heatmap[py + 1][px - 1] + heatmap[py - 1][px - 1])
dyy = 0.25 * (
heatmap[py + 2 * 1][px] - 2 * heatmap[py][px] +
heatmap[py - 2 * 1][px])
derivative = np.array([[dx], [dy]])
hessian = np.array([[dxx, dxy], [dxy, dyy]])
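# Second-order Taylor refinement (DARK): offset = -inv(Hessian) @ gradient,
# applied below only when the Hessian is invertible (non-zero determinant).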
if dxx * dyy - dxy**2 != 0:
hessianinv = np.linalg.inv(hessian)
offset = -hessianinv @ derivative
offset = np.squeeze(np.array(offset.T), axis=0)
coord += offset
return coord
def post_dark_udp(coords, batch_heatmaps, kernel=3):
"""DARK post-pocessing. Implemented by udp. Paper ref: Huang et al. The
Devil is in the Details: Delving into Unbiased Data Processing for Human
Pose Estimation (CVPR 2020). Zhang et al. Distribution-Aware Coordinate
Representation for Human Pose Estimation (CVPR 2020).
Note:
- batch size: B
- num keypoints: K
- num persons: N
- height of heatmaps: H
- width of heatmaps: W
B=1 for bottom_up paradigm where all persons share the same heatmap.
B=N for top_down paradigm where each person has its own heatmaps.
Args:
coords (np.ndarray[N, K, 2]): Initial coordinates of human pose.
batch_heatmaps (np.ndarray[B, K, H, W]): batch_heatmaps
kernel (int): Gaussian kernel size (K) for modulation.
Returns:
np.ndarray([N, K, 2]): Refined coordinates.
"""
if not isinstance(batch_heatmaps, np.ndarray):
batch_heatmaps = batch_heatmaps.cpu().numpy()
B, K, H, W = batch_heatmaps.shape
N = coords.shape[0]
assert (B == 1 or B == N)
for heatmaps in batch_heatmaps:
for heatmap in heatmaps:
cv2.GaussianBlur(heatmap, (kernel, kernel), 0, heatmap)
np.clip(batch_heatmaps, 0.001, 50, batch_heatmaps)
np.log(batch_heatmaps, batch_heatmaps)
batch_heatmaps_pad = np.pad(
batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)),
mode='edge').flatten()
index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (W + 2)
index += (W + 2) * (H + 2) * np.arange(0, B * K).reshape(-1, K)
index = index.astype(int).reshape(-1, 1)
i_ = batch_heatmaps_pad[index]
ix1 = batch_heatmaps_pad[index + 1]
iy1 = batch_heatmaps_pad[index + W + 2]
ix1y1 = batch_heatmaps_pad[index + W + 3]
ix1_y1_ = batch_heatmaps_pad[index - W - 3]
ix1_ = batch_heatmaps_pad[index - 1]
iy1_ = batch_heatmaps_pad[index - 2 - W]
dx = 0.5 * (ix1 - ix1_)
dy = 0.5 * (iy1 - iy1_)
derivative = np.concatenate([dx, dy], axis=1)
derivative = derivative.reshape(N, K, 2, 1)
dxx = ix1 - 2 * i_ + ix1_
dyy = iy1 - 2 * i_ + iy1_
dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1)
hessian = hessian.reshape(N, K, 2, 2)
hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
coords -= np.einsum('ijmn,ijnk->ijmk', hessian, derivative).squeeze()
return coords
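# Usage sketch for post_dark_udp (illustrative; `batch_heatmaps` is a
# hypothetical np.float32 array of shape [B, K, H, W]): DARK/UDP refinement is
# applied to the argmax coordinates, with `kernel` matching the training sigma:
#   coords, maxvals = _get_max_preds(batch_heatmaps)
#   coords = post_dark_udp(coords, batch_heatmaps, kernel=3)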
def _gaussian_blur(heatmaps, kernel=11):
"""Modulate heatmap distribution with Gaussian.
sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8
sigma ~= 3 if k = 17
sigma = 2 if k = 11
sigma ~= 1.5 if k = 7
sigma ~= 1 if k = 3
Note:
- batch_size: N
- num_keypoints: K
- heatmap height: H
- heatmap width: W
Args:
heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps.
kernel (int): Gaussian kernel size (K) for modulation, which should
match the heatmap gaussian sigma when training.
K=17 for sigma=3 and k=11 for sigma=2.
Returns:
np.ndarray ([N, K, H, W]): Modulated heatmap distribution.
"""
assert kernel % 2 == 1
border = (kernel - 1) // 2
batch_size = heatmaps.shape[0]
num_joints = heatmaps.shape[1]
height = heatmaps.shape[2]
width = heatmaps.shape[3]
for i in range(batch_size):
for j in range(num_joints):
origin_max = np.max(heatmaps[i, j])
dr = np.zeros((height + 2 * border, width + 2 * border),
dtype=np.float32)
dr[border:-border, border:-border] = heatmaps[i, j].copy()
dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
heatmaps[i, j] = dr[border:-border, border:-border].copy()
heatmaps[i, j] *= origin_max / np.max(heatmaps[i, j])
return heatmaps
def keypoints_from_regression(regression_preds, center, scale, img_size):
"""Get final keypoint predictions from regression vectors and transform
them back to the image.
Note:
- batch_size: N
- num_keypoints: K
Args:
regression_preds (np.ndarray[N, K, 2]): model prediction.
center (np.ndarray[N, 2]): Center of the bounding box (x, y).
scale (np.ndarray[N, 2]): Scale of the bounding box
wrt height/width.
img_size (list(img_width, img_height)): model input image size.
Returns:
tuple:
- preds (np.ndarray[N, K, 2]): Predicted keypoint location in images.
- maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
"""
N, K, _ = regression_preds.shape
preds, maxvals = regression_preds, np.ones((N, K, 1), dtype=np.float32)
preds = preds * img_size
# Transform back to the image
for i in range(N):
preds[i] = transform_preds(preds[i], center[i], scale[i], img_size)
return preds, maxvals
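# Usage sketch for keypoints_from_regression (illustrative; values are
# hypothetical): regression outputs are normalized to the model input size,
# so `img_size` is [img_width, img_height]:
#   preds, maxvals = keypoints_from_regression(reg_preds, center, scale, [192, 256])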
def keypoints_from_heatmaps(heatmaps,
center,
scale,
unbiased=False,
post_process='default',
kernel=11,
valid_radius_factor=0.0546875,
use_udp=False,
target_type='GaussianHeatmap'):
"""Get final keypoint predictions from heatmaps and transform them back to
the image.
Note:
- batch size: N
- num keypoints: K
- heatmap height: H
- heatmap width: W
Args:
heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps.
center (np.ndarray[N, 2]): Center of the bounding box (x, y).
scale (np.ndarray[N, 2]): Scale of the bounding box
wrt height/width.
post_process (str/None): Choice of methods to post-process
heatmaps. Currently supported: None, 'default', 'unbiased',
'megvii'.
unbiased (bool): Option to use unbiased decoding. Mutually
exclusive with megvii.
Note: this arg is deprecated and unbiased=True can be replaced
by post_process='unbiased'
Paper ref: Zhang et al. Distribution-Aware Coordinate
Representation for Human Pose Estimation (CVPR 2020).
kernel (int): Gaussian kernel size (K) for modulation, which should
match the heatmap gaussian sigma when training.
K=17 for sigma=3 and k=11 for sigma=2.
valid_radius_factor (float): The radius factor of the positive area
in classification heatmap for UDP.
use_udp (bool): Use unbiased data processing.
target_type (str): 'GaussianHeatmap' or 'CombinedTarget'.
GaussianHeatmap: Classification target with gaussian distribution.
CombinedTarget: The combination of classification target
(response map) and regression target (offset map).
Paper ref: Huang et al. The Devil is in the Details: Delving into
Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
Returns:
tuple: A tuple containing keypoint predictions and scores.
- preds (np.ndarray[N, K, 2]): Predicted keypoint location in images.
- maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
"""
# Work on a copy to avoid modifying the input heatmaps in place
heatmaps = heatmaps.copy()
# detect conflicts
if unbiased:
assert post_process not in [False, None, 'megvii']
if post_process in ['megvii', 'unbiased']:
assert kernel > 0
if use_udp:
assert not post_process == 'megvii'
# normalize configs
if post_process is False:
warnings.warn(
'post_process=False is deprecated, '
'please use post_process=None instead', DeprecationWarning)
post_process = None
elif post_process is True:
if unbiased is True:
warnings.warn(
'post_process=True, unbiased=True is deprecated,'
" please use post_process='unbiased' instead",
DeprecationWarning)
post_process = 'unbiased'
else:
warnings.warn(
'post_process=True, unbiased=False is deprecated, '
"please use post_process='default' instead",
DeprecationWarning)
post_process = 'default'
elif post_process == 'default':
if unbiased is True:
warnings.warn(
'unbiased=True is deprecated, please use '
"post_process='unbiased' instead", DeprecationWarning)
post_process = 'unbiased'
# start processing
if post_process == 'megvii':
heatmaps = _gaussian_blur(heatmaps, kernel=kernel)
N, K, H, W = heatmaps.shape
if use_udp:
if target_type.lower() == 'GaussianHeatMap'.lower():
preds, maxvals = _get_max_preds(heatmaps)
preds = post_dark_udp(preds, heatmaps, kernel=kernel)
elif target_type.lower() == 'CombinedTarget'.lower():
for person_heatmaps in heatmaps:
for i, heatmap in enumerate(person_heatmaps):
kt = 2 * kernel + 1 if i % 3 == 0 else kernel
cv2.GaussianBlur(heatmap, (kt, kt), 0, heatmap)
# valid radius is in direct proportion to the height of heatmap.
valid_radius = valid_radius_factor * H
offset_x = heatmaps[:, 1::3, :].flatten() * valid_radius
offset_y = heatmaps[:, 2::3, :].flatten() * valid_radius
heatmaps = heatmaps[:, ::3, :]
preds, maxvals = _get_max_preds(heatmaps)
index = preds[..., 0] + preds[..., 1] * W
index += W * H * np.arange(0, N * K / 3)
index = index.astype(int).reshape(N, K // 3, 1)
preds += np.concatenate((offset_x[index], offset_y[index]), axis=2)
else:
raise ValueError('target_type should be either '
"'GaussianHeatmap' or 'CombinedTarget'")
else:
preds, maxvals = _get_max_preds(heatmaps)
if post_process == 'unbiased': # alleviate biased coordinate
# apply Gaussian distribution modulation.
heatmaps = np.log(
np.maximum(_gaussian_blur(heatmaps, kernel), 1e-10))
for n in range(N):
for k in range(K):
preds[n][k] = _taylor(heatmaps[n][k], preds[n][k])
elif post_process is not None:
# add +/-0.25 shift to the predicted locations for higher acc.
for n in range(N):
for k in range(K):
heatmap = heatmaps[n][k]
px = int(preds[n][k][0])
py = int(preds[n][k][1])
if 1 < px < W - 1 and 1 < py < H - 1:
diff = np.array([
heatmap[py][px + 1] - heatmap[py][px - 1],
heatmap[py + 1][px] - heatmap[py - 1][px]
])
preds[n][k] += np.sign(diff) * .25
if post_process == 'megvii':
preds[n][k] += 0.5
# Transform back to the image
for i in range(N):
preds[i] = transform_preds(
preds[i], center[i], scale[i], [W, H], use_udp=use_udp)
if post_process == 'megvii':
maxvals = maxvals / 255.0 + 0.5
return preds, maxvals
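# Usage sketch for keypoints_from_heatmaps (illustrative; inputs are
# hypothetical): decode heatmaps with the default +/-0.25 shift, or with
# unbiased (DARK) decoding where `kernel` should match the training sigma:
#   preds, maxvals = keypoints_from_heatmaps(heatmaps, center, scale,
#                                            post_process='default')
#   preds, maxvals = keypoints_from_heatmaps(heatmaps, center, scale,
#                                            post_process='unbiased', kernel=11)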
def keypoints_from_heatmaps3d(heatmaps, center, scale):
"""Get final keypoint predictions from 3d heatmaps and transform them back
to the image.
Note:
- batch size: N
- num keypoints: K
- heatmap depth size: D
- heatmap height: H
- heatmap width: W
Args:
heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps.
center (np.ndarray[N, 2]): Center of the bounding box (x, y).
scale (np.ndarray[N, 2]): Scale of the bounding box
wrt height/width.
Returns:
tuple: A tuple containing keypoint predictions and scores.
- preds (np.ndarray[N, K, 3]): Predicted 3d keypoint location \
in images.
- maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
"""
N, K, D, H, W = heatmaps.shape
preds, maxvals = _get_max_preds_3d(heatmaps)
# Transform back to the image
for i in range(N):
preds[i, :, :2] = transform_preds(preds[i, :, :2], center[i], scale[i],
[W, H])
return preds, maxvals
def multilabel_classification_accuracy(pred, gt, mask, thr=0.5):
"""Get multi-label classification accuracy.
Note:
- batch size: N
- label number: L
Args:
pred (np.ndarray[N, L, 2]): model predicted labels.
gt (np.ndarray[N, L, 2]): ground-truth labels.
mask (np.ndarray[N, 1] or np.ndarray[N, L] ): reliability of
ground-truth labels.
Returns:
float: multi-label classification accuracy.
"""
# we only compute accuracy on the samples with ground-truth of all labels.
valid = (mask > 0).min(axis=1) if mask.ndim == 2 else (mask > 0)
pred, gt = pred[valid], gt[valid]
if pred.shape[0] == 0:
acc = 0.0 # when no sample is with gt labels, set acc to 0.
else:
# The classification of a sample is regarded as correct
# only if it's correct for all labels.
acc = (((pred - thr) * (gt - thr)) > 0).all(axis=1).mean()
return acc
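# Usage sketch for multilabel_classification_accuracy (illustrative; inputs
# are hypothetical): a sample counts as correct only when every one of its
# labels lies on the same side of `thr` as the ground truth:
#   acc = multilabel_classification_accuracy(pred, gt, mask, thr=0.5)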
def get_transform(center, scale, res, rot=0):
"""Generate transformation matrix."""
# res: (height, width), (rows, cols)
crop_aspect_ratio = res[0] / float(res[1])
h = 200 * scale
w = h / crop_aspect_ratio
t = np.zeros((3, 3))
t[0, 0] = float(res[1]) / w
t[1, 1] = float(res[0]) / h
t[0, 2] = res[1] * (-float(center[0]) / w + .5)
t[1, 2] = res[0] * (-float(center[1]) / h + .5)
t[2, 2] = 1
if not rot == 0:
rot = -rot # To match direction of rotation from cropping
rot_mat = np.zeros((3, 3))
rot_rad = rot * np.pi / 180
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
rot_mat[0, :2] = [cs, -sn]
rot_mat[1, :2] = [sn, cs]
rot_mat[2, 2] = 1
# Need to rotate around center
t_mat = np.eye(3)
t_mat[0, 2] = -res[1] / 2
t_mat[1, 2] = -res[0] / 2
t_inv = t_mat.copy()
t_inv[:2, 2] *= -1
t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
return t
def transform(pt, center, scale, res, invert=0, rot=0):
"""Transform pixel location to different reference."""
t = get_transform(center, scale, res, rot=rot)
if invert:
t = np.linalg.inv(t)
new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
new_pt = np.dot(t, new_pt)
return np.array([round(new_pt[0]), round(new_pt[1])], dtype=int) + 1
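# Usage sketch for transform (illustrative; `center`, `scale`, `pt` are
# hypothetical): map a pixel location into the (center, scale, res) crop
# reference frame; invert=1 maps crop coordinates back to the original image,
# as done in crop() below:
#   ul = transform([1, 1], center, max(scale), (224, 224), invert=1)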
def bbox_from_detector(bbox, input_resolution=(224, 224), rescale=1.25):
"""
Get center and scale of bounding box from bounding box.
The expected format is [min_x, min_y, max_x, max_y].
"""
CROP_IMG_HEIGHT, CROP_IMG_WIDTH = input_resolution
CROP_ASPECT_RATIO = CROP_IMG_HEIGHT / float(CROP_IMG_WIDTH)
# center
center_x = (bbox[0] + bbox[2]) / 2.0
center_y = (bbox[1] + bbox[3]) / 2.0
center = np.array([center_x, center_y])
# scale
bbox_w = bbox[2] - bbox[0]
bbox_h = bbox[3] - bbox[1]
bbox_size = max(bbox_w * CROP_ASPECT_RATIO, bbox_h)
scale = np.array([bbox_size / CROP_ASPECT_RATIO, bbox_size]) / 200.0
# scale = bbox_size / 200.0
# adjust bounding box tightness
scale *= rescale
return center, scale
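# Usage sketch for bbox_from_detector (illustrative; the box values are
# hypothetical): convert a detector box [min_x, min_y, max_x, max_y] into the
# (center, scale) pair used by crop()/transform(); scale is expressed in units
# of 200 px and enlarged by `rescale`:
#   center, scale = bbox_from_detector([120, 40, 380, 560], (224, 224), rescale=1.25)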
def crop(img, center, scale, res):
"""
Crop image according to the supplied bounding box.
res: [rows, cols]
"""
# Upper left point
ul = np.array(transform([1, 1], center, max(scale), res, invert=1)) - 1
# Bottom right point
br = np.array(transform([res[1] + 1, res[0] + 1], center, max(scale), res, invert=1)) - 1
# Padding so that when rotated proper amount of context is included
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
new_shape = [br[1] - ul[1], br[0] - ul[0]]
if len(img.shape) > 2:
new_shape += [img.shape[2]]
new_img = np.zeros(new_shape, dtype=np.float32)
# Range to fill new array
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
# Range to sample from original image
old_x = max(0, ul[0]), min(len(img[0]), br[0])
old_y = max(0, ul[1]), min(len(img), br[1])
try:
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
except Exception as e:
print(e)
new_img = cv2.resize(new_img, (res[1], res[0])) # (cols, rows)
return new_img, new_shape, (old_x, old_y), (new_x, new_y) # , ul, br
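# Usage sketch for crop (illustrative; `img` and `bbox` are hypothetical):
# crop the person patch defined by the detector box and resize it to
# res=[rows, cols]:
#   center, scale = bbox_from_detector(bbox, (224, 224))
#   patch, new_shape, old_xy, new_xy = crop(img, center, scale, res=[224, 224])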
def split_kp2ds_for_aa(kp2ds, ret_face=False):
kp2ds_body = (kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2
kp2ds_lhand = kp2ds[91:112]
kp2ds_rhand = kp2ds[112:133]
kp2ds_face = kp2ds[22:91]
if ret_face:
return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy(), kp2ds_face.copy()
return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy()
def load_pose_metas_from_kp2ds_seq_list(kp2ds_seq, width, height):
metas = []
last_kp2ds_body = None
for kps in kp2ds_seq:
if len(kps) != 1:
return None
kps = kps[0].copy()
kps[:, 0] /= width
kps[:, 1] /= height
kp2ds_body, kp2ds_lhand, kp2ds_rhand, kp2ds_face = split_kp2ds_for_aa(kps, ret_face=True)
# If every body keypoint has a negative (out-of-frame) coordinate,
# reuse the previous frame's body keypoints.
if kp2ds_body[:, :2].min(axis=1).max() < 0 and last_kp2ds_body is not None:
kp2ds_body = last_kp2ds_body
last_kp2ds_body = kp2ds_body
meta = {
"width": width,
"height": height,
"keypoints_body": kp2ds_body.tolist(),
"keypoints_left_hand": kp2ds_lhand.tolist(),
"keypoints_right_hand": kp2ds_rhand.tolist(),
"keypoints_face": kp2ds_face.tolist(),
}
metas.append(meta)
return metas
def load_pose_metas_from_kp2ds_seq(kp2ds_seq, width, height):
metas = []
last_kp2ds_body = None
for kps in kp2ds_seq:
kps = kps.copy()
kps[:, 0] /= width
kps[:, 1] /= height
kp2ds_body, kp2ds_lhand, kp2ds_rhand, kp2ds_face = split_kp2ds_for_aa(kps, ret_face=True)
# If every body keypoint has a negative (out-of-frame) coordinate,
# reuse the previous frame's body keypoints.
if kp2ds_body[:, :2].min(axis=1).max() < 0 and last_kp2ds_body is not None:
kp2ds_body = last_kp2ds_body
last_kp2ds_body = kp2ds_body
meta = {
"width": width,
"height": height,
"keypoints_body": kp2ds_body,
"keypoints_left_hand": kp2ds_lhand,
"keypoints_right_hand": kp2ds_rhand,
"keypoints_face": kp2ds_face,
}
metas.append(meta)
return metas
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import argparse
from process_pipepline import ProcessPipeline
def _parse_args():
parser = argparse.ArgumentParser(
description="The preprocessing pipeline for Wan-animate."
)
parser.add_argument(
"--ckpt_path",
type=str,
default=None,
help="The path to the preprocessing model's checkpoint directory. ")
parser.add_argument(
"--video_path",
type=str,
default=None,
help="The path to the driving video.")
parser.add_argument(
"--refer_path",
type=str,
default=None,
help="The path to the refererence image.")
parser.add_argument(
"--save_path",
type=str,
default=None,
help="The path to save the processed results.")
parser.add_argument(
"--resolution_area",
type=int,
nargs=2,
default=[1280, 720],
help="The target resolution for processing, specified as [width, height]. To handle different aspect ratios, the video is resized to have a total area equivalent to width * height, while preserving the original aspect ratio."
)
parser.add_argument(
"--fps",
type=int,
default=30,
help="The target FPS for processing the driving video. Set to -1 to use the video's original FPS."
)
parser.add_argument(
"--replace_flag",
action="store_true",
default=False,
help="Whether to use replacement mode.")
parser.add_argument(
"--retarget_flag",
action="store_true",
default=False,
help="Whether to use pose retargeting. Currently only supported in animation mode")
parser.add_argument(
"--use_flux",
action="store_true",
default=False,
help="Whether to use image editing in pose retargeting. Recommended if the character in the reference image or the first frame of the driving video is not in a standard, front-facing pose")
# Parameters for the mask strategy in replacement mode. These control the mask's size and shape. Refer to https://arxiv.org/pdf/2502.06145
parser.add_argument(
"--iterations",
type=int,
default=3,
help="Number of iterations for mask dilation."
)
parser.add_argument(
"--k",
type=int,
default=7,
help="Number of kernel size for mask dilation."
)
parser.add_argument(
"--w_len",
type=int,
default=1,
help="The number of subdivisions for the grid along the 'w' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed."
)
parser.add_argument(
"--h_len",
type=int,
default=1,
help="The number of subdivisions for the grid along the 'h' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed."
)
args = parser.parse_args()
return args
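# Example invocation (a sketch; the script name and all paths below are
# illustrative placeholders):
#   python preprocess.py --ckpt_path ./process_checkpoint \
#       --video_path ./driving.mp4 --refer_path ./ref.png --save_path ./output \
#       --resolution_area 1280 720 --retarget_flag --use_flux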
if __name__ == '__main__':
args = _parse_args()
args_dict = vars(args)
print(args_dict)
assert len(args.resolution_area) == 2, "resolution_area should be a list of two integers [width, height]"
assert not args.use_flux or args.retarget_flag, "Image editing with FLUX can only be used when pose retargeting is enabled."
pose2d_checkpoint_path = os.path.join(args.ckpt_path, 'pose2d/vitpose_h_wholebody.onnx')
det_checkpoint_path = os.path.join(args.ckpt_path, 'det/yolov10m.onnx')
sam2_checkpoint_path = os.path.join(args.ckpt_path, 'sam2/sam2_hiera_large.pt') if args.replace_flag else None
flux_kontext_path = os.path.join(args.ckpt_path, 'FLUX.1-Kontext-dev') if args.use_flux else None
process_pipeline = ProcessPipeline(det_checkpoint_path=det_checkpoint_path, pose2d_checkpoint_path=pose2d_checkpoint_path, sam_checkpoint_path=sam2_checkpoint_path, flux_kontext_path=flux_kontext_path)
os.makedirs(args.save_path, exist_ok=True)
process_pipeline(video_path=args.video_path,
refer_image_path=args.refer_path,
output_path=args.save_path,
resolution_area=args.resolution_area,
fps=args.fps,
iterations=args.iterations,
k=args.k,
w_len=args.w_len,
h_len=args.h_len,
retarget_flag=args.retarget_flag,
use_flux=args.use_flux,
replace_flag=args.replace_flag)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import numpy as np
import shutil
import torch
from diffusers import FluxKontextPipeline
import cv2
from loguru import logger
from PIL import Image
try:
import moviepy.editor as mpy
except ImportError:
import moviepy as mpy
from decord import VideoReader
from pose2d import Pose2d
from pose2d_utils import AAPoseMeta
from utils import resize_by_area, get_frame_indices, padding_resize, get_face_bboxes, get_aug_mask, get_mask_body_img
from human_visualization import draw_aapose_by_meta_new
from retarget_pose import get_retarget_pose
import sam2.modeling.sam.transformer as transformer
transformer.USE_FLASH_ATTN = False
transformer.MATH_KERNEL_ON = True
transformer.OLD_GPU = True
from sam_utils import build_sam2_video_predictor
class ProcessPipeline():
def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_checkpoint_path, flux_kontext_path):
self.pose2d = Pose2d(checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path)
model_cfg = "sam2_hiera_l.yaml"
if sam_checkpoint_path is not None:
self.predictor = build_sam2_video_predictor(model_cfg, sam_checkpoint_path)
if flux_kontext_path is not None:
self.flux_kontext = FluxKontextPipeline.from_pretrained(flux_kontext_path, torch_dtype=torch.bfloat16).to("cuda")
def __call__(self, video_path, refer_image_path, output_path, resolution_area=[1280, 720], fps=30, iterations=3, k=7, w_len=1, h_len=1, retarget_flag=False, use_flux=False, replace_flag=False):
if replace_flag:
video_reader = VideoReader(video_path)
frame_num = len(video_reader)
print('frame_num: {}'.format(frame_num))
video_fps = video_reader.get_avg_fps()
print('video_fps: {}'.format(video_fps))
print('fps: {}'.format(fps))
# TODO: Maybe we can switch to PyAV later, which can get accurate frame num
duration = video_reader.get_frame_timestamp(-1)[-1]
expected_frame_num = int(duration * video_fps + 0.5)
ratio = abs((frame_num - expected_frame_num)/frame_num)
if ratio > 0.1:
print("Warning: The difference between the actual number of frames and the expected number of frames is two large")
frame_num = expected_frame_num
if fps == -1:
fps = video_fps
target_num = int(frame_num / video_fps * fps)
print('target_num: {}'.format(target_num))
idxs = get_frame_indices(frame_num, video_fps, target_num, fps)
frames = video_reader.get_batch(idxs).asnumpy()
frames = [resize_by_area(frame, resolution_area[0] * resolution_area[1], divisor=16) for frame in frames]
height, width = frames[0].shape[:2]
logger.info(f"Processing pose meta")
tpl_pose_metas = self.pose2d(frames)
face_images = []
for idx, meta in enumerate(tpl_pose_metas):
face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,
image_shape=(frames[0].shape[0], frames[0].shape[1]))
x1, x2, y1, y2 = face_bbox_for_image
face_image = frames[idx][y1:y2, x1:x2]
face_image = cv2.resize(face_image, (512, 512))
face_images.append(face_image)
logger.info(f"Processing reference image: {refer_image_path}")
refer_img = cv2.imread(refer_image_path)
src_ref_path = os.path.join(output_path, 'src_ref.png')
shutil.copy(refer_image_path, src_ref_path)
refer_img = refer_img[..., ::-1]
refer_img = padding_resize(refer_img, height, width)
logger.info(f"Processing template video: {video_path}")
tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]
cond_images = []
for idx, meta in enumerate(tpl_retarget_pose_metas):
canvas = np.zeros_like(refer_img)
conditioning_image = draw_aapose_by_meta_new(canvas, meta)
cond_images.append(conditioning_image)
masks = self.get_mask(frames, 400, tpl_pose_metas)
bg_images = []
aug_masks = []
for frame, mask in zip(frames, masks):
if iterations > 0:
_, each_mask = get_mask_body_img(frame, mask, iterations=iterations, k=k)
each_aug_mask = get_aug_mask(each_mask, w_len=w_len, h_len=h_len)
else:
each_aug_mask = mask
each_bg_image = frame * (1 - each_aug_mask[:, :, None])
bg_images.append(each_bg_image)
aug_masks.append(each_aug_mask)
src_face_path = os.path.join(output_path, 'src_face.mp4')
mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path)
src_pose_path = os.path.join(output_path, 'src_pose.mp4')
mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path)
src_bg_path = os.path.join(output_path, 'src_bg.mp4')
mpy.ImageSequenceClip(bg_images, fps=fps).write_videofile(src_bg_path)
aug_masks_new = [np.stack([mask * 255, mask * 255, mask * 255], axis=2) for mask in aug_masks]
src_mask_path = os.path.join(output_path, 'src_mask.mp4')
mpy.ImageSequenceClip(aug_masks_new, fps=fps).write_videofile(src_mask_path)
return True
else:
logger.info(f"Processing reference image: {refer_image_path}")
refer_img = cv2.imread(refer_image_path)
src_ref_path = os.path.join(output_path, 'src_ref.png')
shutil.copy(refer_image_path, src_ref_path)
refer_img = refer_img[..., ::-1]
refer_img = resize_by_area(refer_img, resolution_area[0] * resolution_area[1], divisor=16)
refer_pose_meta = self.pose2d([refer_img])[0]
logger.info(f"Processing template video: {video_path}")
video_reader = VideoReader(video_path)
frame_num = len(video_reader)
print('frame_num: {}'.format(frame_num))
video_fps = video_reader.get_avg_fps()
print('video_fps: {}'.format(video_fps))
print('fps: {}'.format(fps))
# TODO: Maybe we can switch to PyAV later, which can get accurate frame num
duration = video_reader.get_frame_timestamp(-1)[-1]
expected_frame_num = int(duration * video_fps + 0.5)
ratio = abs((frame_num - expected_frame_num)/frame_num)
if ratio > 0.1:
print("Warning: The difference between the actual number of frames and the expected number of frames is two large")
frame_num = expected_frame_num
if fps == -1:
fps = video_fps
target_num = int(frame_num / video_fps * fps)
print('target_num: {}'.format(target_num))
idxs = get_frame_indices(frame_num, video_fps, target_num, fps)
frames = video_reader.get_batch(idxs).asnumpy()
logger.info(f"Processing pose meta")
tpl_pose_meta0 = self.pose2d(frames[:1])[0]
tpl_pose_metas = self.pose2d(frames)
face_images = []
for idx, meta in enumerate(tpl_pose_metas):
face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,
image_shape=(frames[0].shape[0], frames[0].shape[1]))
x1, x2, y1, y2 = face_bbox_for_image
face_image = frames[idx][y1:y2, x1:x2]
face_image = cv2.resize(face_image, (512, 512))
face_images.append(face_image)
if retarget_flag:
if use_flux:
tpl_prompt, refer_prompt = self.get_editing_prompts(tpl_pose_metas, refer_pose_meta)
refer_input = Image.fromarray(refer_img)
refer_edit = self.flux_kontext(
image=refer_input,
height=refer_img.shape[0],
width=refer_img.shape[1],
prompt=refer_prompt,
guidance_scale=2.5,
num_inference_steps=28,
).images[0]
refer_edit = Image.fromarray(padding_resize(np.array(refer_edit), refer_img.shape[0], refer_img.shape[1]))
refer_edit_path = os.path.join(output_path, 'refer_edit.png')
refer_edit.save(refer_edit_path)
refer_edit_pose_meta = self.pose2d([np.array(refer_edit)])[0]
tpl_img = frames[1]
tpl_input = Image.fromarray(tpl_img)
tpl_edit = self.flux_kontext(
image=tpl_input,
height=tpl_img.shape[0],
width=tpl_img.shape[1],
prompt=tpl_prompt,
guidance_scale=2.5,
num_inference_steps=28,
).images[0]
tpl_edit = Image.fromarray(padding_resize(np.array(tpl_edit), tpl_img.shape[0], tpl_img.shape[1]))
tpl_edit_path = os.path.join(output_path, 'tpl_edit.png')
tpl_edit.save(tpl_edit_path)
tpl_edit_pose_meta0 = self.pose2d([np.array(tpl_edit)])[0]
tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tpl_edit_pose_meta0, refer_edit_pose_meta)
else:
tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, None, None)
else:
tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]
cond_images = []
for idx, meta in enumerate(tpl_retarget_pose_metas):
if retarget_flag:
canvas = np.zeros_like(refer_img)
conditioning_image = draw_aapose_by_meta_new(canvas, meta)
else:
canvas = np.zeros_like(frames[0])
conditioning_image = draw_aapose_by_meta_new(canvas, meta)
conditioning_image = padding_resize(conditioning_image, refer_img.shape[0], refer_img.shape[1])
cond_images.append(conditioning_image)
src_face_path = os.path.join(output_path, 'src_face.mp4')
mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path)
src_pose_path = os.path.join(output_path, 'src_pose.mp4')
mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path)
return True
def get_editing_prompts(self, tpl_pose_metas, refer_pose_meta):
arm_visible = False
leg_visible = False
for tpl_pose_meta in tpl_pose_metas:
tpl_keypoints = tpl_pose_meta['keypoints_body']
if tpl_keypoints[3].all() != 0 or tpl_keypoints[4].all() != 0 or tpl_keypoints[6].all() != 0 or tpl_keypoints[7].all() != 0:
if (tpl_keypoints[3][0] <= 1 and tpl_keypoints[3][1] <= 1 and tpl_keypoints[3][2] >= 0.75) or (tpl_keypoints[4][0] <= 1 and tpl_keypoints[4][1] <= 1 and tpl_keypoints[4][2] >= 0.75) or \
(tpl_keypoints[6][0] <= 1 and tpl_keypoints[6][1] <= 1 and tpl_keypoints[6][2] >= 0.75) or (tpl_keypoints[7][0] <= 1 and tpl_keypoints[7][1] <= 1 and tpl_keypoints[7][2] >= 0.75):
arm_visible = True
if tpl_keypoints[9].all() != 0 or tpl_keypoints[12].all() != 0 or tpl_keypoints[10].all() != 0 or tpl_keypoints[13].all() != 0:
if (tpl_keypoints[9][0] <= 1 and tpl_keypoints[9][1] <= 1 and tpl_keypoints[9][2] >= 0.75) or (tpl_keypoints[12][0] <= 1 and tpl_keypoints[12][1] <= 1 and tpl_keypoints[12][2] >= 0.75) or \
(tpl_keypoints[10][0] <= 1 and tpl_keypoints[10][1] <= 1 and tpl_keypoints[10][2] >= 0.75) or (tpl_keypoints[13][0] <= 1 and tpl_keypoints[13][1] <= 1 and tpl_keypoints[13][2] >= 0.75):
leg_visible = True
if arm_visible and leg_visible:
break
if leg_visible:
if tpl_pose_meta['width'] > tpl_pose_meta['height']:
tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image."
else:
tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image."
if refer_pose_meta['width'] > refer_pose_meta['height']:
refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image."
else:
refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image."
elif arm_visible:
if tpl_pose_meta['width'] > tpl_pose_meta['height']:
tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image."
else:
tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image."
if refer_pose_meta['width'] > refer_pose_meta['height']:
refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image."
else:
refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image."
else:
tpl_prompt = "Change the person to face forward."
refer_prompt = "Change the person to face forward."
return tpl_prompt, refer_prompt
def get_mask(self, frames, th_step, kp2ds_all):
frame_num = len(frames)
if frame_num < th_step:
num_step = 1
else:
num_step = (frame_num + th_step) // th_step
all_mask = []
for index in range(num_step):
each_frames = frames[index * th_step:(index + 1) * th_step]
kp2ds = kp2ds_all[index * th_step:(index + 1) * th_step]
if len(each_frames) > 4:
key_frame_num = 4
elif 4 >= len(each_frames) > 0:
key_frame_num = 1
else:
continue
key_frame_step = len(kp2ds) // key_frame_num
key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))
key_points_index = [0, 1, 2, 5, 8, 11, 10, 13]
key_frame_body_points_list = []
for key_frame_index in key_frame_index_list:
keypoints_body_list = []
body_key_points = kp2ds[key_frame_index]['keypoints_body']
for each_index in key_points_index:
each_keypoint = body_key_points[each_index]
if each_keypoint is None:
continue
keypoints_body_list.append(each_keypoint)
keypoints_body = np.array(keypoints_body_list)[:, :2]
wh = np.array([[kp2ds[0]['width'], kp2ds[0]['height']]])
points = (keypoints_body * wh).astype(np.int32)
key_frame_body_points_list.append(points)
inference_state = self.predictor.init_state_v2(frames=each_frames)
self.predictor.reset_state(inference_state)
ann_obj_id = 1
for ann_frame_idx, points in zip(key_frame_index_list, key_frame_body_points_list):
labels = np.array([1] * points.shape[0], np.int32)
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=points,
labels=labels,
)
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
for out_frame_idx in range(len(video_segments)):
for out_obj_id, out_mask in video_segments[out_frame_idx].items():
out_mask = out_mask[0].astype(np.uint8)
all_mask.append(out_mask)
return all_mask
def convert_list_to_array(self, metas):
metas_list = []
for meta in metas:
for key, value in meta.items():
if type(value) is list:
value = np.array(value)
meta[key] = value
metas_list.append(meta)
return metas_list
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import cv2
import numpy as np
import json
from tqdm import tqdm
import math
from typing import NamedTuple, List
import copy
from pose2d_utils import AAPoseMeta
# Skeleton keypoint names and bone (limb) connections
keypoint_list = [
"Nose",
"Neck",
"RShoulder",
"RElbow",
"RWrist", # No.4
"LShoulder",
"LElbow",
"LWrist", # No.7
"RHip",
"RKnee",
"RAnkle", # No.10
"LHip",
"LKnee",
"LAnkle", # No.13
"REye",
"LEye",
"REar",
"LEar",
"LToe",
"RToe",
]
limbSeq = [
[2, 3], [2, 6], # shoulders
[3, 4], [4, 5], # right arm (RShoulder-RElbow-RWrist)
[6, 7], [7, 8], # left arm (LShoulder-LElbow-LWrist)
[2, 9], [9, 10], [10, 11], # right leg
[2, 12], [12, 13], [13, 14], # left leg
[2, 1], [1, 15], [15, 17], [1, 16], [16, 18], # face (nose, eyes, ears)
[14, 19], # left foot
[11, 20] # right foot
]
eps = 0.01
class Keypoint(NamedTuple):
x: float
y: float
score: float = 1.0
id: int = -1
# For each limb, compute the source and destination bone lengths
# and their ratio.
def get_length(skeleton, limb):
k1_index, k2_index = limb
H, W = skeleton['height'], skeleton['width']
keypoints = skeleton['keypoints_body']
keypoint1 = keypoints[k1_index - 1]
keypoint2 = keypoints[k2_index - 1]
if keypoint1 is None or keypoint2 is None:
return None, None, None
X = np.array([keypoint1[0], keypoint2[0]]) * float(W)
Y = np.array([keypoint1[1], keypoint2[1]]) * float(H)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
return X, Y, length
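# Usage sketch for get_length (illustrative; `skeleton` is a hypothetical pose
# meta dict): limbs in limbSeq are 1-indexed into keypoints_body, so the
# neck-to-left-shoulder bone [2, 6] has pixel length:
#   X, Y, length = get_length(skeleton, [2, 6])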
def get_handpose_meta(keypoints, delta, src_H, src_W):
new_keypoints = []
for idx, keypoint in enumerate(keypoints):
if keypoint is None:
new_keypoints.append(None)
continue
if keypoint.score == 0:
new_keypoints.append(None)
continue
x, y = keypoint.x, keypoint.y
x = int(x * src_W + delta[0])
y = int(y * src_H + delta[1])
new_keypoints.append(
Keypoint(
x=x,
y=y,
score=keypoint.score,
))
return new_keypoints
def deal_hand_keypoints(hand_res, r_ratio, l_ratio, hand_score_th = 0.5):
left_hand = []
right_hand = []
left_delta_x = hand_res['left'][0][0] * (l_ratio - 1)
left_delta_y = hand_res['left'][0][1] * (l_ratio - 1)
right_delta_x = hand_res['right'][0][0] * (r_ratio - 1)
right_delta_y = hand_res['right'][0][1] * (r_ratio - 1)
length = len(hand_res['left'])
for i in range(length):
# left hand
if hand_res['left'][i][2] < hand_score_th:
left_hand.append(
Keypoint(
x=-1,
y=-1,
score=0,
)
)
else:
left_hand.append(
Keypoint(
x=hand_res['left'][i][0] * l_ratio - left_delta_x,
y=hand_res['left'][i][1] * l_ratio - left_delta_y,
score = hand_res['left'][i][2]
)
)
# right hand
if hand_res['right'][i][2] < hand_score_th:
right_hand.append(
Keypoint(
x=-1,
y=-1,
score=0,
)
)
else:
right_hand.append(
Keypoint(
x=hand_res['right'][i][0] * r_ratio - right_delta_x,
y=hand_res['right'][i][1] * r_ratio - right_delta_y,
score = hand_res['right'][i][2]
)
)
return right_hand, left_hand
def get_scaled_pose(canvas, src_canvas, keypoints, keypoints_hand, bone_ratio_list, delta_ground_x, delta_ground_y,
rescaled_src_ground_x, body_flag, id, scale_min, threshold = 0.4):
H, W = canvas
src_H, src_W = src_canvas
new_length_list = [ ]
angle_list = [ ]
# keypoints from 0-1 to H/W range
for idx in range(len(keypoints)):
if keypoints[idx] is None or len(keypoints[idx]) == 0:
continue
keypoints[idx] = [keypoints[idx][0] * src_W, keypoints[idx][1] * src_H, keypoints[idx][2]]
# first traverse, get new_length_list and angle_list
for idx, (k1_index, k2_index) in enumerate(limbSeq):
keypoint1 = keypoints[k1_index - 1]
keypoint2 = keypoints[k2_index - 1]
if keypoint1 is None or keypoint2 is None or len(keypoint1) == 0 or len(keypoint2) == 0:
new_length_list.append(None)
angle_list.append(None)
continue
Y = np.array([keypoint1[0], keypoint2[0]]) #* float(W)
X = np.array([keypoint1[1], keypoint2[1]]) #* float(H)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
new_length = length * bone_ratio_list[idx]
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
new_length_list.append(new_length)
angle_list.append(angle)
# Keep foot length within 0.5x calf length
foot_lower_leg_ratio = 0.5
if new_length_list[8] != None and new_length_list[18] != None:
if new_length_list[18] > new_length_list[8] * foot_lower_leg_ratio:
new_length_list[18] = new_length_list[8] * foot_lower_leg_ratio
if new_length_list[11] != None and new_length_list[17] != None:
if new_length_list[17] > new_length_list[11] * foot_lower_leg_ratio:
new_length_list[17] = new_length_list[11] * foot_lower_leg_ratio
# second traverse, calculate new keypoints
rescale_keypoints = keypoints.copy()
for idx, (k1_index, k2_index) in enumerate(limbSeq):
# update dst_keypoints
start_keypoint = rescale_keypoints[k1_index - 1]
new_length = new_length_list[idx]
angle = angle_list[idx]
if rescale_keypoints[k1_index - 1] is None or rescale_keypoints[k2_index - 1] is None or \
len(rescale_keypoints[k1_index - 1]) == 0 or len(rescale_keypoints[k2_index - 1]) == 0:
continue
# calculate end_keypoint
delta_x = new_length * math.cos(math.radians(angle))
delta_y = new_length * math.sin(math.radians(angle))
end_keypoint_x = start_keypoint[0] - delta_x
end_keypoint_y = start_keypoint[1] - delta_y
# update keypoints
rescale_keypoints[k2_index - 1] = [end_keypoint_x, end_keypoint_y, rescale_keypoints[k2_index - 1][2]]
if id == 0:
if body_flag == 'full_body' and rescale_keypoints[8] != None and rescale_keypoints[11] != None:
delta_ground_x_offset_first_frame = (rescale_keypoints[8][0] + rescale_keypoints[11][0]) / 2 - rescaled_src_ground_x
delta_ground_x += delta_ground_x_offset_first_frame
elif body_flag == 'half_body' and rescale_keypoints[1] != None:
delta_ground_x_offset_first_frame = rescale_keypoints[1][0] - rescaled_src_ground_x
delta_ground_x += delta_ground_x_offset_first_frame
# offset all keypoints
for idx in range(len(rescale_keypoints)):
if rescale_keypoints[idx] is None or len(rescale_keypoints[idx]) == 0 :
continue
rescale_keypoints[idx][0] -= delta_ground_x
rescale_keypoints[idx][1] -= delta_ground_y
# rescale keypoints to original size
rescale_keypoints[idx][0] /= scale_min
rescale_keypoints[idx][1] /= scale_min
# Scale hand proportions based on body skeletal ratios
r_ratio = max(bone_ratio_list[0], bone_ratio_list[1]) / scale_min
l_ratio = max(bone_ratio_list[0], bone_ratio_list[1]) / scale_min
left_hand, right_hand = deal_hand_keypoints(keypoints_hand, r_ratio, l_ratio, hand_score_th = threshold)
left_hand_new = left_hand.copy()
right_hand_new = right_hand.copy()
if rescale_keypoints[4] == None and rescale_keypoints[7] == None:
pass
elif rescale_keypoints[4] == None and rescale_keypoints[7] != None:
right_hand_delta = np.array(rescale_keypoints[7][:2]) - np.array(keypoints[7][:2])
right_hand_new = get_handpose_meta(right_hand, right_hand_delta, src_H, src_W)
elif rescale_keypoints[4] != None and rescale_keypoints[7] == None:
left_hand_delta = np.array(rescale_keypoints[4][:2]) - np.array(keypoints[4][:2])
left_hand_new = get_handpose_meta(left_hand, left_hand_delta, src_H, src_W)
else:
# get left_hand and right_hand offset
left_hand_delta = np.array(rescale_keypoints[4][:2]) - np.array(keypoints[4][:2])
right_hand_delta = np.array(rescale_keypoints[7][:2]) - np.array(keypoints[7][:2])
if keypoints[4][0] != None and left_hand[0].x != -1:
left_hand_root_offset = np.array( ( keypoints[4][0] - left_hand[0].x * src_W, keypoints[4][1] - left_hand[0].y * src_H))
left_hand_delta += left_hand_root_offset
if keypoints[7][0] != None and right_hand[0].x != -1:
right_hand_root_offset = np.array( ( keypoints[7][0] - right_hand[0].x * src_W, keypoints[7][1] - right_hand[0].y * src_H))
right_hand_delta += right_hand_root_offset
dis_left_hand = ((keypoints[4][0] - left_hand[0].x * src_W) ** 2 + (keypoints[4][1] - left_hand[0].y * src_H) ** 2) ** 0.5
dis_right_hand = ((keypoints[7][0] - left_hand[0].x * src_W) ** 2 + (keypoints[7][1] - left_hand[0].y * src_H) ** 2) ** 0.5
if dis_left_hand > dis_right_hand:
right_hand_new = get_handpose_meta(left_hand, right_hand_delta, src_H, src_W)
left_hand_new = get_handpose_meta(right_hand, left_hand_delta, src_H, src_W)
else:
left_hand_new = get_handpose_meta(left_hand, left_hand_delta, src_H, src_W)
right_hand_new = get_handpose_meta(right_hand, right_hand_delta, src_H, src_W)
# get normalized keypoints_body
norm_body_keypoints = [ ]
for body_keypoint in rescale_keypoints:
if body_keypoint != None:
norm_body_keypoints.append([body_keypoint[0] / W , body_keypoint[1] / H, body_keypoint[2]])
else:
norm_body_keypoints.append(None)
frame_info = {
'height': H,
'width': W,
'keypoints_body': norm_body_keypoints,
'keypoints_left_hand' : left_hand_new,
'keypoints_right_hand' : right_hand_new,
}
return frame_info
def rescale_skeleton(H, W, keypoints, bone_ratio_list):
rescale_keypoints = keypoints.copy()
new_length_list = [ ]
angle_list = [ ]
# keypoints from 0-1 to H/W range
for idx in range(len(rescale_keypoints)):
if rescale_keypoints[idx] is None or len(rescale_keypoints[idx]) == 0:
continue
rescale_keypoints[idx] = [rescale_keypoints[idx][0] * W, rescale_keypoints[idx][1] * H]
# first traverse, get new_length_list and angle_list
for idx, (k1_index, k2_index) in enumerate(limbSeq):
keypoint1 = rescale_keypoints[k1_index - 1]
keypoint2 = rescale_keypoints[k2_index - 1]
if keypoint1 is None or keypoint2 is None or len(keypoint1) == 0 or len(keypoint2) == 0:
new_length_list.append(None)
angle_list.append(None)
continue
Y = np.array([keypoint1[0], keypoint2[0]]) #* float(W)
X = np.array([keypoint1[1], keypoint2[1]]) #* float(H)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
new_length = length * bone_ratio_list[idx]
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
new_length_list.append(new_length)
angle_list.append(angle)
# second traverse, calculate new keypoints
for idx, (k1_index, k2_index) in enumerate(limbSeq):
# update dst_keypoints
start_keypoint = rescale_keypoints[k1_index - 1]
new_length = new_length_list[idx]
angle = angle_list[idx]
if rescale_keypoints[k1_index - 1] is None or rescale_keypoints[k2_index - 1] is None or \
len(rescale_keypoints[k1_index - 1]) == 0 or len(rescale_keypoints[k2_index - 1]) == 0:
continue
# calculate end_keypoint
delta_x = new_length * math.cos(math.radians(angle))
delta_y = new_length * math.sin(math.radians(angle))
end_keypoint_x = start_keypoint[0] - delta_x
end_keypoint_y = start_keypoint[1] - delta_y
# update keypoints
rescale_keypoints[k2_index - 1] = [end_keypoint_x, end_keypoint_y]
return rescale_keypoints
def fix_lack_keypoints_use_sym(skeleton):
keypoints = skeleton['keypoints_body']
H, W = skeleton['height'], skeleton['width']
limb_points_list = [
[3, 4, 5],
[6, 7, 8],
[12, 13, 14, 19],
[9, 10, 11, 20],
]
for limb_points in limb_points_list:
miss_flag = False
for point in limb_points:
if keypoints[point - 1] is None:
miss_flag = True
continue
if miss_flag:
skeleton['keypoints_body'][point - 1] = None
repair_limb_seq_left = [
[3, 4], [4, 5], # left arm
[12, 13], [13, 14], # left leg
[14, 19] # left foot
]
repair_limb_seq_right = [
[6, 7], [7, 8], # right arm
[9, 10], [10, 11], # right leg
[11, 20] # right foot
]
repair_limb_seq = [repair_limb_seq_left, repair_limb_seq_right]
for idx_part, part in enumerate(repair_limb_seq):
for idx, limb in enumerate(part):
k1_index, k2_index = limb
keypoint1 = keypoints[k1_index - 1]
keypoint2 = keypoints[k2_index - 1]
if keypoint1 != None and keypoint2 is None:
# reference to symmetric limb
sym_limb = repair_limb_seq[1-idx_part][idx]
k1_index_sym, k2_index_sym = sym_limb
keypoint1_sym = keypoints[k1_index_sym - 1]
keypoint2_sym = keypoints[k2_index_sym - 1]
ref_length = 0
if keypoint1_sym != None and keypoint2_sym != None:
X = np.array([keypoint1_sym[0], keypoint2_sym[0]]) * float(W)
Y = np.array([keypoint1_sym[1], keypoint2_sym[1]]) * float(H)
ref_length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
else:
ref_length_left, ref_length_right = 0, 0
if keypoints[1] != None and keypoints[8] != None:
X = np.array([keypoints[1][0], keypoints[8][0]]) * float(W)
Y = np.array([keypoints[1][1], keypoints[8][1]]) * float(H)
ref_length_left = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
if idx <= 1: # arms
ref_length_left /= 2
if keypoints[1] != None and keypoints[11] != None:
X = np.array([keypoints[1][0], keypoints[11][0]]) * float(W)
Y = np.array([keypoints[1][1], keypoints[11][1]]) * float(H)
ref_length_right = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
if idx <= 1: # arms
ref_length_right /= 2
elif idx == 4: # foot
ref_length_right /= 5
ref_length = max(ref_length_left, ref_length_right)
if ref_length != 0:
skeleton['keypoints_body'][k2_index - 1] = [0, 0] #init
skeleton['keypoints_body'][k2_index - 1][0] = skeleton['keypoints_body'][k1_index - 1][0]
skeleton['keypoints_body'][k2_index - 1][1] = skeleton['keypoints_body'][k1_index - 1][1] + ref_length / H
return skeleton
def rescale_shorten_skeleton(ratio_list, src_length_list, dst_length_list):
modify_bone_list = [
[0, 1],
[2, 4],
[3, 5],
[6, 9],
[7, 10],
[8, 11],
[17, 18]
]
for modify_bone in modify_bone_list:
new_ratio = max(ratio_list[modify_bone[0]], ratio_list[modify_bone[1]])
ratio_list[modify_bone[0]] = new_ratio
ratio_list[modify_bone[1]] = new_ratio
if ratio_list[13]!= None and ratio_list[15]!= None:
ratio_eye_avg = (ratio_list[13] + ratio_list[15]) / 2
ratio_list[13] = ratio_eye_avg
ratio_list[15] = ratio_eye_avg
if ratio_list[14]!= None and ratio_list[16]!= None:
ratio_eye_avg = (ratio_list[14] + ratio_list[16]) / 2
ratio_list[14] = ratio_eye_avg
ratio_list[16] = ratio_eye_avg
return ratio_list, src_length_list, dst_length_list
def check_full_body(keypoints, threshold = 0.4):
body_flag = 'half_body'
# 1. If ankle points exist, confidence is greater than the threshold, and points do not exceed the frame, return full_body
if keypoints[10] != None and keypoints[13] != None and keypoints[8] != None and keypoints[11] != None:
if (keypoints[10][1] <= 1 and keypoints[13][1] <= 1) and (keypoints[10][2] >= threshold and keypoints[13][2] >= threshold) and \
(keypoints[8][1] <= 1 and keypoints[11][1] <= 1) and (keypoints[8][2] >= threshold and keypoints[11][2] >= threshold):
body_flag = 'full_body'
return body_flag
# 2. If hip points exist, return three_quarter_body
if (keypoints[8] != None and keypoints[11] != None):
if (keypoints[8][1] <= 1 and keypoints[11][1] <= 1) and (keypoints[8][2] >= threshold and keypoints[11][2] >= threshold):
body_flag = 'three_quarter_body'
return body_flag
return body_flag
def check_full_body_both(flag1, flag2):
body_flag_dict = {
'full_body': 2,
'three_quarter_body' : 1,
'half_body': 0
}
body_flag_dict_reverse = {
2: 'full_body',
1: 'three_quarter_body',
0: 'half_body'
}
flag1_num = body_flag_dict[flag1]
flag2_num = body_flag_dict[flag2]
flag_both_num = min(flag1_num, flag2_num)
return body_flag_dict_reverse[flag_both_num]
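# Usage sketch for check_full_body_both (illustrative; the skeletons are
# hypothetical): the per-skeleton flags are merged so retargeting uses the
# less visible of the two bodies:
#   body_flag = check_full_body_both(check_full_body(dst_kps), check_full_body(src_kps))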
def write_to_poses(data_to_json, none_idx, dst_shape, bone_ratio_list, delta_ground_x, delta_ground_y, rescaled_src_ground_x, body_flag, scale_min):
outputs = []
length = len(data_to_json)
for id in tqdm(range(length)):
src_height, src_width = data_to_json[id]['height'], data_to_json[id]['width']
width, height = dst_shape
keypoints = data_to_json[id]['keypoints_body']
for idx in range(len(keypoints)):
if idx in none_idx:
keypoints[idx] = None
new_keypoints = keypoints.copy()
# get hand keypoints
keypoints_hand = {'left' : data_to_json[id]['keypoints_left_hand'], 'right' : data_to_json[id]['keypoints_right_hand']}
# Normalize hand coordinates to 0-1 range
for hand_idx in range(len(data_to_json[id]['keypoints_left_hand'])):
data_to_json[id]['keypoints_left_hand'][hand_idx][0] = data_to_json[id]['keypoints_left_hand'][hand_idx][0] / src_width
data_to_json[id]['keypoints_left_hand'][hand_idx][1] = data_to_json[id]['keypoints_left_hand'][hand_idx][1] / src_height
for hand_idx in range(len(data_to_json[id]['keypoints_right_hand'])):
data_to_json[id]['keypoints_right_hand'][hand_idx][0] = data_to_json[id]['keypoints_right_hand'][hand_idx][0] / src_width
data_to_json[id]['keypoints_right_hand'][hand_idx][1] = data_to_json[id]['keypoints_right_hand'][hand_idx][1] / src_height
frame_info = get_scaled_pose((height, width), (src_height, src_width), new_keypoints, keypoints_hand, bone_ratio_list, delta_ground_x, delta_ground_y, rescaled_src_ground_x, body_flag, id, scale_min)
outputs.append(frame_info)
return outputs
def calculate_scale_ratio(skeleton, skeleton_edit, scale_ratio_flag):
if scale_ratio_flag:
headw = max(skeleton['keypoints_body'][0][0], skeleton['keypoints_body'][14][0], skeleton['keypoints_body'][15][0], skeleton['keypoints_body'][16][0], skeleton['keypoints_body'][17][0]) - \
min(skeleton['keypoints_body'][0][0], skeleton['keypoints_body'][14][0], skeleton['keypoints_body'][15][0], skeleton['keypoints_body'][16][0], skeleton['keypoints_body'][17][0])
headw_edit = max(skeleton_edit['keypoints_body'][0][0], skeleton_edit['keypoints_body'][14][0], skeleton_edit['keypoints_body'][15][0], skeleton_edit['keypoints_body'][16][0], skeleton_edit['keypoints_body'][17][0]) - \
min(skeleton_edit['keypoints_body'][0][0], skeleton_edit['keypoints_body'][14][0], skeleton_edit['keypoints_body'][15][0], skeleton_edit['keypoints_body'][16][0], skeleton_edit['keypoints_body'][17][0])
headw_ratio = headw / headw_edit
_, _, shoulder = get_length(skeleton, [6,3])
_, _, shoulder_edit = get_length(skeleton_edit, [6,3])
shoulder_ratio = shoulder / shoulder_edit
return max(headw_ratio, shoulder_ratio)
else:
return 1
def retarget_pose(src_skeleton, dst_skeleton, all_src_skeleton, src_skeleton_edit, dst_skeleton_edit, threshold=0.4):
if src_skeleton_edit is not None and dst_skeleton_edit is not None:
use_edit_for_base = True
else:
use_edit_for_base = False
src_skeleton_ori = copy.deepcopy(src_skeleton)
dst_skeleton_ori_h, dst_skeleton_ori_w = dst_skeleton['height'], dst_skeleton['width']
if src_skeleton['keypoints_body'][0] != None and src_skeleton['keypoints_body'][10] != None and src_skeleton['keypoints_body'][13] != None and \
dst_skeleton['keypoints_body'][0] != None and dst_skeleton['keypoints_body'][10] != None and dst_skeleton['keypoints_body'][13] != None and \
src_skeleton['keypoints_body'][0][2] > 0.5 and src_skeleton['keypoints_body'][10][2] > 0.5 and src_skeleton['keypoints_body'][13][2] > 0.5 and \
dst_skeleton['keypoints_body'][0][2] > 0.5 and dst_skeleton['keypoints_body'][10][2] > 0.5 and dst_skeleton['keypoints_body'][13][2] > 0.5:
src_height = src_skeleton['height'] * abs(
(src_skeleton['keypoints_body'][10][1] + src_skeleton['keypoints_body'][13][1]) / 2 -
src_skeleton['keypoints_body'][0][1])
dst_height = dst_skeleton['height'] * abs(
(dst_skeleton['keypoints_body'][10][1] + dst_skeleton['keypoints_body'][13][1]) / 2 -
dst_skeleton['keypoints_body'][0][1])
scale_min = 1.0 * src_height / dst_height
elif src_skeleton['keypoints_body'][0] != None and src_skeleton['keypoints_body'][8] != None and src_skeleton['keypoints_body'][11] != None and \
dst_skeleton['keypoints_body'][0] != None and dst_skeleton['keypoints_body'][8] != None and dst_skeleton['keypoints_body'][11] != None and \
src_skeleton['keypoints_body'][0][2] > 0.5 and src_skeleton['keypoints_body'][8][2] > 0.5 and src_skeleton['keypoints_body'][11][2] > 0.5 and \
dst_skeleton['keypoints_body'][0][2] > 0.5 and dst_skeleton['keypoints_body'][8][2] > 0.5 and dst_skeleton['keypoints_body'][11][2] > 0.5:
src_height = src_skeleton['height'] * abs(
(src_skeleton['keypoints_body'][8][1] + src_skeleton['keypoints_body'][11][1]) / 2 -
src_skeleton['keypoints_body'][0][1])
dst_height = dst_skeleton['height'] * abs(
(dst_skeleton['keypoints_body'][8][1] + dst_skeleton['keypoints_body'][11][1]) / 2 -
dst_skeleton['keypoints_body'][0][1])
scale_min = 1.0 * src_height / dst_height
else:
scale_min = np.sqrt(src_skeleton['height'] * src_skeleton['width']) / np.sqrt(dst_skeleton['height'] * dst_skeleton['width'])
if use_edit_for_base:
scale_ratio_flag = False
if src_skeleton_edit['keypoints_body'][0] != None and src_skeleton_edit['keypoints_body'][10] != None and src_skeleton_edit['keypoints_body'][13] != None and \
dst_skeleton_edit['keypoints_body'][0] != None and dst_skeleton_edit['keypoints_body'][10] != None and dst_skeleton_edit['keypoints_body'][13] != None and \
src_skeleton_edit['keypoints_body'][0][2] > 0.5 and src_skeleton_edit['keypoints_body'][10][2] > 0.5 and src_skeleton_edit['keypoints_body'][13][2] > 0.5 and \
dst_skeleton_edit['keypoints_body'][0][2] > 0.5 and dst_skeleton_edit['keypoints_body'][10][2] > 0.5 and dst_skeleton_edit['keypoints_body'][13][2] > 0.5:
src_height_edit = src_skeleton_edit['height'] * abs(
(src_skeleton_edit['keypoints_body'][10][1] + src_skeleton_edit['keypoints_body'][13][1]) / 2 -
src_skeleton_edit['keypoints_body'][0][1])
dst_height_edit = dst_skeleton_edit['height'] * abs(
(dst_skeleton_edit['keypoints_body'][10][1] + dst_skeleton_edit['keypoints_body'][13][1]) / 2 -
dst_skeleton_edit['keypoints_body'][0][1])
scale_min_edit = 1.0 * src_height_edit / dst_height_edit
elif src_skeleton_edit['keypoints_body'][0] != None and src_skeleton_edit['keypoints_body'][8] != None and src_skeleton_edit['keypoints_body'][11] != None and \
dst_skeleton_edit['keypoints_body'][0] != None and dst_skeleton_edit['keypoints_body'][8] != None and dst_skeleton_edit['keypoints_body'][11] != None and \
src_skeleton_edit['keypoints_body'][0][2] > 0.5 and src_skeleton_edit['keypoints_body'][8][2] > 0.5 and src_skeleton_edit['keypoints_body'][11][2] > 0.5 and \
dst_skeleton_edit['keypoints_body'][0][2] > 0.5 and dst_skeleton_edit['keypoints_body'][8][2] > 0.5 and dst_skeleton_edit['keypoints_body'][11][2] > 0.5:
src_height_edit = src_skeleton_edit['height'] * abs(
(src_skeleton_edit['keypoints_body'][8][1] + src_skeleton_edit['keypoints_body'][11][1]) / 2 -
src_skeleton_edit['keypoints_body'][0][1])
dst_height_edit = dst_skeleton_edit['height'] * abs(
(dst_skeleton_edit['keypoints_body'][8][1] + dst_skeleton_edit['keypoints_body'][11][1]) / 2 -
dst_skeleton_edit['keypoints_body'][0][1])
scale_min_edit = 1.0 * src_height_edit / dst_height_edit
else:
scale_min_edit = np.sqrt(src_skeleton_edit['height'] * src_skeleton_edit['width']) / np.sqrt(dst_skeleton_edit['height'] * dst_skeleton_edit['width'])
scale_ratio_flag = True
# Flux may change the scale, compensate for it here
ratio_src = calculate_scale_ratio(src_skeleton, src_skeleton_edit, scale_ratio_flag)
ratio_dst = calculate_scale_ratio(dst_skeleton, dst_skeleton_edit, scale_ratio_flag)
dst_skeleton_edit['height'] = int(dst_skeleton_edit['height'] * scale_min_edit)
dst_skeleton_edit['width'] = int(dst_skeleton_edit['width'] * scale_min_edit)
for idx in range(len(dst_skeleton_edit['keypoints_left_hand'])):
dst_skeleton_edit['keypoints_left_hand'][idx][0] *= scale_min_edit
dst_skeleton_edit['keypoints_left_hand'][idx][1] *= scale_min_edit
for idx in range(len(dst_skeleton_edit['keypoints_right_hand'])):
dst_skeleton_edit['keypoints_right_hand'][idx][0] *= scale_min_edit
dst_skeleton_edit['keypoints_right_hand'][idx][1] *= scale_min_edit
dst_skeleton['height'] = int(dst_skeleton['height'] * scale_min)
dst_skeleton['width'] = int(dst_skeleton['width'] * scale_min)
for idx in range(len(dst_skeleton['keypoints_left_hand'])):
dst_skeleton['keypoints_left_hand'][idx][0] *= scale_min
dst_skeleton['keypoints_left_hand'][idx][1] *= scale_min
for idx in range(len(dst_skeleton['keypoints_right_hand'])):
dst_skeleton['keypoints_right_hand'][idx][0] *= scale_min
dst_skeleton['keypoints_right_hand'][idx][1] *= scale_min
dst_body_flag = check_full_body(dst_skeleton['keypoints_body'], threshold)
src_body_flag = check_full_body(src_skeleton_ori['keypoints_body'], threshold)
body_flag = check_full_body_both(dst_body_flag, src_body_flag)
#print('body_flag: ', body_flag)
if use_edit_for_base:
src_skeleton_edit = fix_lack_keypoints_use_sym(src_skeleton_edit)
dst_skeleton_edit = fix_lack_keypoints_use_sym(dst_skeleton_edit)
else:
src_skeleton = fix_lack_keypoints_use_sym(src_skeleton)
dst_skeleton = fix_lack_keypoints_use_sym(dst_skeleton)
none_idx = []
for idx in range(len(dst_skeleton['keypoints_body'])):
if dst_skeleton['keypoints_body'][idx] is None or src_skeleton['keypoints_body'][idx] is None:
src_skeleton['keypoints_body'][idx] = None
dst_skeleton['keypoints_body'][idx] = None
none_idx.append(idx)
# get bone ratio list
ratio_list, src_length_list, dst_length_list = [], [], []
for idx, limb in enumerate(limbSeq):
if use_edit_for_base:
src_X, src_Y, src_length = get_length(src_skeleton_edit, limb)
dst_X, dst_Y, dst_length = get_length(dst_skeleton_edit, limb)
if src_X is None or src_Y is None or dst_X is None or dst_Y is None:
ratio = -1
else:
ratio = 1.0 * dst_length * ratio_dst / src_length / ratio_src
else:
src_X, src_Y, src_length = get_length(src_skeleton, limb)
dst_X, dst_Y, dst_length = get_length(dst_skeleton, limb)
if src_X is None or src_Y is None or dst_X is None or dst_Y is None:
ratio = -1
else:
ratio = 1.0 * dst_length / src_length
ratio_list.append(ratio)
src_length_list.append(src_length)
dst_length_list.append(dst_length)
for idx, ratio in enumerate(ratio_list):
if ratio == -1:
if ratio_list[0] != -1 and ratio_list[1] != -1:
ratio_list[idx] = (ratio_list[0] + ratio_list[1]) / 2
# Consider adding constraints when Flux fails to correct head pose, causing neck issues.
# if ratio_list[12] > (ratio_list[0]+ratio_list[1])/2*1.25:
# ratio_list[12] = (ratio_list[0]+ratio_list[1])/2*1.25
ratio_list, src_length_list, dst_length_list = rescale_shorten_skeleton(ratio_list, src_length_list, dst_length_list)
rescaled_src_skeleton_ori = rescale_skeleton(src_skeleton_ori['height'], src_skeleton_ori['width'],
src_skeleton_ori['keypoints_body'], ratio_list)
# get global translation offset_x and offset_y
if body_flag == 'full_body':
#print('use foot mark.')
dst_ground_y = max(dst_skeleton['keypoints_body'][10][1], dst_skeleton['keypoints_body'][13][1]) * dst_skeleton[
'height']
# The midpoint between toe and ankle
if dst_skeleton['keypoints_body'][18] is not None and dst_skeleton['keypoints_body'][19] is not None:
right_foot_mid = (dst_skeleton['keypoints_body'][10][1] + dst_skeleton['keypoints_body'][19][1]) / 2
left_foot_mid = (dst_skeleton['keypoints_body'][13][1] + dst_skeleton['keypoints_body'][18][1]) / 2
dst_ground_y = max(left_foot_mid, right_foot_mid) * dst_skeleton['height']
rescaled_src_ground_y = max(rescaled_src_skeleton_ori[10][1], rescaled_src_skeleton_ori[13][1])
delta_ground_y = rescaled_src_ground_y - dst_ground_y
dst_ground_x = (dst_skeleton['keypoints_body'][8][0] + dst_skeleton['keypoints_body'][11][0]) * dst_skeleton[
'width'] / 2
rescaled_src_ground_x = (rescaled_src_skeleton_ori[8][0] + rescaled_src_skeleton_ori[11][0]) / 2
delta_ground_x = rescaled_src_ground_x - dst_ground_x
delta_x, delta_y = delta_ground_x, delta_ground_y
else:
#print('use neck mark.')
# use neck keypoint as mark
src_neck_y = rescaled_src_skeleton_ori[1][1]
dst_neck_y = dst_skeleton['keypoints_body'][1][1]
delta_neck_y = src_neck_y - dst_neck_y * dst_skeleton['height']
src_neck_x = rescaled_src_skeleton_ori[1][0]
dst_neck_x = dst_skeleton['keypoints_body'][1][0]
delta_neck_x = src_neck_x - dst_neck_x * dst_skeleton['width']
delta_x, delta_y = delta_neck_x, delta_neck_y
rescaled_src_ground_x = src_neck_x
dst_shape = (dst_skeleton_ori_w, dst_skeleton_ori_h)
output = write_to_poses(all_src_skeleton, none_idx, dst_shape, ratio_list, delta_x, delta_y,
rescaled_src_ground_x, body_flag, scale_min)
return output
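# Illustrative sketch only (not used by the pipeline above): the retargeting
# logic boils down to computing a per-limb length ratio between a destination
# (reference) skeleton and a source (template) skeleton and rescaling the
# source bones with it. The toy keypoint dict and limb pairs below are
# hypothetical examples, not the real `limbSeq` of this module.
def _demo_bone_length_ratios(src_kps, dst_kps, limbs):
    """src_kps / dst_kps: {joint_idx: (x, y)}; limbs: list of (idx_a, idx_b)."""
    import math
    ratios = []
    for a, b in limbs:
        src_len = math.dist(src_kps[a], src_kps[b])
        dst_len = math.dist(dst_kps[a], dst_kps[b])
        # mirror the convention above: -1 marks a limb that cannot be measured
        ratios.append(dst_len / src_len if src_len > 0 else -1)
    return ratios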
def get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tql_edit_pose_meta0, refer_edit_pose_meta):
for key, value in tpl_pose_meta0.items():
if type(value) is np.ndarray:
if key in ['keypoints_left_hand', 'keypoints_right_hand']:
value = value * np.array([[tpl_pose_meta0["width"], tpl_pose_meta0["height"], 1.0]])
if not isinstance(value, list):
value = value.tolist()
tpl_pose_meta0[key] = value
for key, value in refer_pose_meta.items():
if type(value) is np.ndarray:
if key in ['keypoints_left_hand', 'keypoints_right_hand']:
value = value * np.array([[refer_pose_meta["width"], refer_pose_meta["height"], 1.0]])
if not isinstance(value, list):
value = value.tolist()
refer_pose_meta[key] = value
tpl_pose_metas_new = []
for meta in tpl_pose_metas:
for key, value in meta.items():
if type(value) is np.ndarray:
if key in ['keypoints_left_hand', 'keypoints_right_hand']:
value = value * np.array([[meta["width"], meta["height"], 1.0]])
if not isinstance(value, list):
value = value.tolist()
meta[key] = value
tpl_pose_metas_new.append(meta)
if tql_edit_pose_meta0 is not None:
for key, value in tql_edit_pose_meta0.items():
if type(value) is np.ndarray:
if key in ['keypoints_left_hand', 'keypoints_right_hand']:
value = value * np.array([[tql_edit_pose_meta0["width"], tql_edit_pose_meta0["height"], 1.0]])
if not isinstance(value, list):
value = value.tolist()
tql_edit_pose_meta0[key] = value
if refer_edit_pose_meta is not None:
for key, value in refer_edit_pose_meta.items():
if type(value) is np.ndarray:
if key in ['keypoints_left_hand', 'keypoints_right_hand']:
value = value * np.array([[refer_edit_pose_meta["width"], refer_edit_pose_meta["height"], 1.0]])
if not isinstance(value, list):
value = value.tolist()
refer_edit_pose_meta[key] = value
retarget_tpl_pose_metas = retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas_new, tql_edit_pose_meta0, refer_edit_pose_meta)
pose_metas = []
for meta in retarget_tpl_pose_metas:
pose_meta = AAPoseMeta()
width, height = meta["width"], meta["height"]
pose_meta.width = width
pose_meta.height = height
pose_meta.kps_body = np.array(meta["keypoints_body"])[:, :2] * (width, height)
pose_meta.kps_body_p = np.array(meta["keypoints_body"])[:, 2]
kps_lhand = []
kps_lhand_p = []
for each_kps_lhand in meta["keypoints_left_hand"]:
if each_kps_lhand is not None:
kps_lhand.append([each_kps_lhand.x, each_kps_lhand.y])
kps_lhand_p.append(each_kps_lhand.score)
else:
kps_lhand.append([None, None])
kps_lhand_p.append(0.0)
pose_meta.kps_lhand = np.array(kps_lhand)
pose_meta.kps_lhand_p = np.array(kps_lhand_p)
kps_rhand = []
kps_rhand_p = []
for each_kps_rhand in meta["keypoints_right_hand"]:
if each_kps_rhand is not None:
kps_rhand.append([each_kps_rhand.x, each_kps_rhand.y])
kps_rhand_p.append(each_kps_rhand.score)
else:
kps_rhand.append([None, None])
kps_rhand_p.append(0.0)
pose_meta.kps_rhand = np.array(kps_rhand)
pose_meta.kps_rhand_p = np.array(kps_rhand_p)
pose_metas.append(pose_meta)
return pose_metas
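# Hedged usage sketch for `get_retarget_pose`: each pose meta is expected to be
# a dict with "width", "height", "keypoints_body", "keypoints_left_hand" and
# "keypoints_right_hand" entries (see the loops above). The pose-extraction
# helper and variable names below are assumptions for illustration only.
#
#   tpl_metas = extract_pose_metas(template_frames)        # hypothetical helper
#   ref_meta = extract_pose_metas([reference_image])[0]    # hypothetical helper
#   retargeted = get_retarget_pose(tpl_metas[0], ref_meta, tpl_metas, None, None)
#   for pm in retargeted:
#       print(pm.width, pm.height, pm.kps_body.shape)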
# Copyright (c) 2025. Your modifications here.
# This file wraps and extends sam2.utils.misc for custom modifications.
from sam2.utils import misc as sam2_misc
from sam2.utils.misc import *
from PIL import Image
import numpy as np
import torch
from tqdm import tqdm
import os
import logging
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from sam2.utils.misc import AsyncVideoFrameLoader, _load_img_as_tensor
from sam2.build_sam import _load_checkpoint
def _load_img_v2_as_tensor(img, image_size):
img_pil = Image.fromarray(img.astype(np.uint8))
img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
img_np = img_np / 255.0
else:
raise RuntimeError(f"Unknown image dtype: {img_np.dtype}")
img = torch.from_numpy(img_np).permute(2, 0, 1)
video_width, video_height = img_pil.size # the original video size
return img, video_height, video_width
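# Minimal sanity check for `_load_img_v2_as_tensor` (illustrative only): a
# synthetic 480x640 uint8 RGB frame should come back as a (3, 512, 512) float
# tensor together with the original frame height/width. The dummy frame is an
# assumption made for this sketch.
def _demo_load_img_v2_as_tensor():
    dummy = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
    tensor, height, width = _load_img_v2_as_tensor(dummy, image_size=512)
    assert tensor.shape == (3, 512, 512)
    assert (height, width) == (480, 640)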
def load_video_frames(
video_path,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
frame_names=None,
):
"""
Load the video frames from a directory of image files named "<frame_index>.jpg" (".jpeg" and ".png" are also accepted).
The frames are resized to image_size x image_size and are loaded to GPU if
`offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
You can load a frame asynchronously by setting `async_loading_frames` to `True`.
"""
if isinstance(video_path, str) and os.path.isdir(video_path):
jpg_folder = video_path
else:
raise NotImplementedError("Only JPEG frames are supported at this moment")
if frame_names is None:
frame_names = [
p
for p in os.listdir(jpg_folder)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
num_frames = len(frame_names)
if num_frames == 0:
raise RuntimeError(f"no images found in {jpg_folder}")
img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
if async_loading_frames:
lazy_images = AsyncVideoFrameLoader(
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
)
return lazy_images, lazy_images.video_height, lazy_images.video_width
images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
images = images.cuda()
img_mean = img_mean.cuda()
img_std = img_std.cuda()
# normalize by mean and std
images -= img_mean
images /= img_std
return images, video_height, video_width
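# Hedged usage sketch for `load_video_frames`: the directory below is a
# placeholder path; it must contain frames named "<frame_index>.jpg" as the
# docstring above describes.
#
#   images, vid_h, vid_w = load_video_frames(
#       video_path="/path/to/frames_dir",   # hypothetical directory
#       image_size=1024,
#       offload_video_to_cpu=True,
#   )
#   print(images.shape)  # (num_frames, 3, 1024, 1024), already mean/std normalized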
def load_video_frames_v2(
frames,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
frame_names=None,
):
"""
Load video frames from an in-memory sequence of RGB frames (each an
H x W x 3 uint8 array). The frames are resized to image_size x image_size
and are kept on CPU if `offload_video_to_cpu` is `True`, otherwise moved
to GPU. `async_loading_frames` and `frame_names` are accepted for API
compatibility but are not used here.
"""
num_frames = len(frames)
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, frame in enumerate(tqdm(frames, desc="video frame")):
images[n], video_height, video_width = _load_img_v2_as_tensor(frame, image_size)
if not offload_video_to_cpu:
images = images.cuda()
img_mean = img_mean.cuda()
img_std = img_std.cuda()
# normalize by mean and std
images -= img_mean
images /= img_std
return images, video_height, video_width
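# Minimal sketch for `load_video_frames_v2` with in-memory frames (e.g. decoded
# from a video elsewhere). The synthetic all-black frames are an assumption.
def _demo_load_video_frames_v2():
    frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(4)]
    images, vid_h, vid_w = load_video_frames_v2(
        frames, image_size=512, offload_video_to_cpu=True)
    assert images.shape == (4, 3, 512, 512)
    assert (vid_h, vid_w) == (480, 640)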
def build_sam2_video_predictor(
config_file,
ckpt_path=None,
device="cuda",
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
):
hydra_overrides = [
"++model._target_=video_predictor.SAM2VideoPredictor",
]
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
# binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder, so the encoded masks match exactly what users see from clicking
"++model.binarize_mask_from_pts_for_mem_enc=true",
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
"++model.fill_hole_area=8",
]
hydra_overrides.extend(hydra_overrides_extra)
# Read config and init model
cfg = compose(config_name=config_file, overrides=hydra_overrides)
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)
_load_checkpoint(model, ckpt_path)
model = model.to(device)
if mode == "eval":
model.eval()
return model
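# Hedged usage sketch: the config name and checkpoint path below are
# placeholders rather than files shipped with this repo, and Hydra must already
# be initialized with the SAM 2 config search path before `compose` is called
# inside `build_sam2_video_predictor`.
#
#   predictor = build_sam2_video_predictor(
#       config_file="sam2_hiera_l.yaml",                  # assumed config name
#       ckpt_path="checkpoints/sam2_hiera_large.pt",      # assumed local path
#       device="cuda",
#   )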
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import cv2
import math
import random
import numpy as np
def get_mask_boxes(mask):
"""
Args:
mask: [h, w]
Returns:
"""
y_coords, x_coords = np.nonzero(mask)
x_min = x_coords.min()
x_max = x_coords.max()
y_min = y_coords.min()
y_max = y_coords.max()
bbox = np.array([x_min, y_min, x_max, y_max]).astype(np.int32)
return bbox
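# Tiny illustrative check for `get_mask_boxes`: a 2x2 block of ones inside a
# 6x6 mask yields the box [x_min, y_min, x_max, y_max] = [2, 1, 3, 2].
def _demo_get_mask_boxes():
    mask = np.zeros((6, 6), dtype=np.uint8)
    mask[1:3, 2:4] = 1  # rows 1-2, columns 2-3
    assert get_mask_boxes(mask).tolist() == [2, 1, 3, 2]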
def get_aug_mask(body_mask, w_len=10, h_len=20):
body_bbox = get_mask_boxes(body_mask)
bbox_wh = body_bbox[2:4] - body_bbox[0:2]
w_slice = max(1, int(bbox_wh[0] // w_len))  # guard against a zero step for very small boxes
h_slice = max(1, int(bbox_wh[1] // h_len))
for each_w in range(body_bbox[0], body_bbox[2], w_slice):
w_start = min(each_w, body_bbox[2])
w_end = min((each_w + w_slice), body_bbox[2])
# print(w_start, w_end)
for each_h in range(body_bbox[1], body_bbox[3], h_slice):
h_start = min(each_h, body_bbox[3])
h_end = min((each_h + h_slice), body_bbox[3])
if body_mask[h_start:h_end, w_start:w_end].sum() > 0:
body_mask[h_start:h_end, w_start:w_end] = 1
return body_mask
def get_mask_body_img(img_copy, hand_mask, k=7, iterations=1):
kernel = np.ones((k, k), np.uint8)
dilation = cv2.dilate(hand_mask, kernel, iterations=iterations)
mask_hand_img = img_copy * (1 - dilation[:, :, None])
return mask_hand_img, dilation
def get_face_bboxes(kp2ds, scale, image_shape, ratio_aug):
h, w = image_shape
kp2ds_face = kp2ds.copy()[23:91, :2]
min_x, min_y = np.min(kp2ds_face, axis=0)
max_x, max_y = np.max(kp2ds_face, axis=0)
initial_width = max_x - min_x
initial_height = max_y - min_y
initial_area = initial_width * initial_height
expanded_area = initial_area * scale
new_width = np.sqrt(expanded_area * (initial_width / initial_height))
new_height = np.sqrt(expanded_area * (initial_height / initial_width))
delta_width = (new_width - initial_width) / 2
delta_height = (new_height - initial_height) / 4
if ratio_aug:
if random.random() > 0.5:
delta_width += random.uniform(0, initial_width // 10)
else:
delta_height += random.uniform(0, initial_height // 10)
expanded_min_x = max(min_x - delta_width, 0)
expanded_max_x = min(max_x + delta_width, w)
expanded_min_y = max(min_y - 3 * delta_height, 0)
expanded_max_y = min(max_y + delta_height, h)
return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)]
def calculate_new_size(orig_w, orig_h, target_area, divisor=64):
target_ratio = orig_w / orig_h
def check_valid(w, h):
if w <= 0 or h <= 0:
return False
return (w * h <= target_area and
w % divisor == 0 and
h % divisor == 0)
def get_ratio_diff(w, h):
return abs(w / h - target_ratio)
def round_to_64(value, round_up=False, divisor=64):
if round_up:
return divisor * ((value + (divisor - 1)) // divisor)
return divisor * (value // divisor)
possible_sizes = []
max_area_h = int(np.sqrt(target_area / target_ratio))
max_area_w = int(max_area_h * target_ratio)
max_h = round_to_64(max_area_h, round_up=True, divisor=divisor)
max_w = round_to_64(max_area_w, round_up=True, divisor=divisor)
for h in range(divisor, max_h + divisor, divisor):
ideal_w = h * target_ratio
w_down = round_to_64(ideal_w, divisor=divisor)
w_up = round_to_64(ideal_w, round_up=True, divisor=divisor)
for w in [w_down, w_up]:
if check_valid(w, h):
possible_sizes.append((w, h, get_ratio_diff(w, h)))
if not possible_sizes:
raise ValueError("Can not find suitable size")
possible_sizes.sort(key=lambda x: (-x[0] * x[1], x[2]))
best_w, best_h, _ = possible_sizes[0]
return int(best_w), int(best_h)
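# Hedged example for `calculate_new_size`: it returns the largest
# divisor-aligned (w, h) whose area stays within `target_area` while best
# matching the input aspect ratio. The 1920x1080 input is just an example.
def _demo_calculate_new_size():
    w, h = calculate_new_size(1920, 1080, target_area=1024 * 576, divisor=64)
    assert w % 64 == 0 and h % 64 == 0 and w * h <= 1024 * 576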
def resize_by_area(image, target_area, keep_aspect_ratio=True, divisor=64, padding_color=(0, 0, 0)):
h, w = image.shape[:2]
try:
new_w, new_h = calculate_new_size(w, h, target_area, divisor)
except ValueError:  # fall back when no divisor-aligned size fits the target area
aspect_ratio = w / h
if keep_aspect_ratio:
new_h = math.sqrt(target_area / aspect_ratio)
new_w = target_area / new_h
else:
new_w = new_h = math.sqrt(target_area)
new_w, new_h = int((new_w // divisor) * divisor), int((new_h // divisor) * divisor)
interpolation = cv2.INTER_AREA if (new_w * new_h < w * h) else cv2.INTER_LINEAR
resized_image = padding_resize(image, height=new_h, width=new_w, padding_color=padding_color,
interpolation=interpolation)
return resized_image
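# Hedged usage sketch for `resize_by_area`: shrink a synthetic 1080p frame so
# its area is at most 1280*720 with 64-pixel-aligned sides (padding with black
# where the aspect ratio cannot be matched exactly).
def _demo_resize_by_area():
    image = np.zeros((1080, 1920, 3), dtype=np.uint8)
    resized = resize_by_area(image, target_area=1280 * 720, divisor=64)
    out_h, out_w = resized.shape[:2]
    assert out_h % 64 == 0 and out_w % 64 == 0 and out_h * out_w <= 1280 * 720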
def padding_resize(img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):
ori_height = img_ori.shape[0]
ori_width = img_ori.shape[1]
channel = img_ori.shape[2]
img_pad = np.zeros((height, width, channel))
if channel == 1:
img_pad[:, :, 0] = padding_color[0]
else:
img_pad[:, :, 0] = padding_color[0]
img_pad[:, :, 1] = padding_color[1]
img_pad[:, :, 2] = padding_color[2]
if (ori_height / ori_width) > (height / width):
new_width = int(height / ori_height * ori_width)
img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)
padding = int((width - new_width) / 2)
if len(img.shape) == 2:
img = img[:, :, np.newaxis]
img_pad[:, padding: padding + new_width, :] = img
else:
new_height = int(width / ori_width * ori_height)
img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)
padding = int((height - new_height) / 2)
if len(img.shape) == 2:
img = img[:, :, np.newaxis]
img_pad[padding: padding + new_height, :, :] = img
img_pad = np.uint8(img_pad)
return img_pad
def get_frame_indices(frame_num, video_fps, clip_length, train_fps):
start_frame = 0
times = np.arange(0, clip_length) / train_fps
frame_indices = start_frame + np.round(times * video_fps).astype(int)
frame_indices = np.clip(frame_indices, 0, frame_num - 1)
return frame_indices.tolist()
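# Small illustration of `get_frame_indices`: resample a 150-frame, 30 fps clip
# to 81 training frames at 16 fps. The numbers are arbitrary examples.
def _demo_get_frame_indices():
    indices = get_frame_indices(frame_num=150, video_fps=30, clip_length=81, train_fps=16)
    assert len(indices) == 81
    assert indices[0] == 0 and max(indices) <= 149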
def get_face_bboxes(kp2ds, scale, image_shape):
h, w = image_shape
kp2ds_face = kp2ds.copy()[1:] * (w, h)
min_x, min_y = np.min(kp2ds_face, axis=0)
max_x, max_y = np.max(kp2ds_face, axis=0)
initial_width = max_x - min_x
initial_height = max_y - min_y
initial_area = initial_width * initial_height
expanded_area = initial_area * scale
new_width = np.sqrt(expanded_area * (initial_width / initial_height))
new_height = np.sqrt(expanded_area * (initial_height / initial_width))
delta_width = (new_width - initial_width) / 2
delta_height = (new_height - initial_height) / 4
expanded_min_x = max(min_x - delta_width, 0)
expanded_max_x = min(max_x + delta_width, w)
expanded_min_y = max(min_y - 3 * delta_height, 0)
expanded_max_y = min(max_y + delta_height, h)
return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)]