Unverified Commit 39683e24 authored by gushiqiao, committed by GitHub
parent d75d7720
@@ -12,9 +12,9 @@
     "sample_guide_scale": 5.0,
     "enable_cfg": false,
     "cpu_offload": false,
-    "src_pose_path": "/path/to/animate/process_results/src_pose.mp4",
-    "src_face_path": "/path/to/animate/process_results/src_face.mp4",
-    "src_ref_images": "/path/to/animate/process_results/src_ref.png",
+    "src_pose_path": "../save_results/animate/process_results/src_pose.mp4",
+    "src_face_path": "../save_results/animate/process_results/src_face.mp4",
+    "src_ref_images": "../save_results/animate/process_results/src_ref.png",
     "refert_num": 1,
     "replace_flag": false,
     "fps": 30
{
"infer_steps": 20,
"target_video_length": 77,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "sage_attn2",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"adapter_attn_type": "sage_attn2",
"sample_shift": 5.0,
"sample_guide_scale": 5.0,
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "phase",
"src_pose_path": "../save_results/animate/process_results/src_pose.mp4",
"src_face_path": "../save_results/animate/process_results/src_face.mp4",
"src_ref_images": "../save_results/animate/process_results/src_ref.png",
"refert_num": 1,
"replace_flag": false,
"fps": 30,
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F"
},
"t5_quantized": true,
"t5_quant_scheme": "fp8",
"clip_quantized": true,
"clip_quant_scheme": "fp8"
}
@@ -12,11 +12,11 @@
     "sample_guide_scale": 5.0,
     "enable_cfg": false,
     "cpu_offload": false,
-    "src_pose_path": "/path/to/replace/process_results/src_pose.mp4",
-    "src_face_path": "/path/to/replace/process_results/src_face.mp4",
-    "src_ref_images": "/path/to/replace/process_results/src_ref.png",
-    "src_bg_path": "/path/to/replace/process_results/src_bg.mp4",
-    "src_mask_path": "/path/to/replace/process_results/src_mask.mp4",
+    "src_pose_path": "../save_results/replace/process_results/src_pose.mp4",
+    "src_face_path": "../save_results/replace/process_results/src_face.mp4",
+    "src_ref_images": "../save_results/replace/process_results/src_ref.png",
+    "src_bg_path": "../save_results/replace/process_results/src_bg.mp4",
+    "src_mask_path": "../save_results/replace/process_results/src_mask.mp4",
     "refert_num": 1,
     "fps": 30,
     "replace_flag": true
{
"infer_steps": 20,
"target_video_length": 77,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "sage_attn2",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"adapter_attn_type": "sage_attn2",
"sample_shift": 5.0,
"sample_guide_scale": 5.0,
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "phase",
"src_pose_path": "../save_results/replace/process_results/src_pose.mp4",
"src_face_path": "../save_results/replace/process_results/src_face.mp4",
"src_ref_images": "../save_results/replace/process_results/src_ref.png",
"src_bg_path": "../save_results/replace/process_results/src_bg.mp4",
"src_mask_path": "../save_results/replace/process_results/src_mask.mp4",
"refert_num": 1,
"fps": 30,
"replace_flag": true,
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F"
},
"t5_quantized": true,
"t5_quant_scheme": "fp8",
"clip_quantized": true,
"clip_quant_scheme": "fp8"
}
@@ -311,7 +311,7 @@ class WanAnimateRunner(WanRunner):
                 dtype=GET_DTYPE(),
             )  # c t h w
         else:
-            refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach()  # c t h w
+            refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach().cuda()  # c t h w
         bg_pixel_values, mask_pixel_values = None, None
         if self.config["replace_flag"] if "replace_flag" in self.config else False:
@@ -3,17 +3,28 @@
 # set paths first
 lightx2v_path=
 model_path=
+video_path=
+refer_path=

-export CUDA_VISIBLE_DEVICES=7
+export CUDA_VISIBLE_DEVICES=0

 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh

+# process
+python ${lightx2v_path}/tools/preprocess/preprocess_data.py \
+    --ckpt_path ${model_path}/process_checkpoint \
+    --video_path $video_path \
+    --refer_path $refer_path \
+    --save_path ${lightx2v_path}/save_results/animate/process_results \
+    --resolution_area 1280 720 \
+    --retarget_flag
+
 python -m lightx2v.infer \
     --model_cls wan2.2_animate \
     --task animate \
     --model_path $model_path \
-    --config_json ${lightx2v_path}/configs/wan22/wan_animate_replace.json \
+    --config_json ${lightx2v_path}/configs/wan22/wan_animate.json \
     --prompt "视频中的人在做动作" \
     --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
     --save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_animate.mp4
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
video_path=
refer_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
# process
python ${lightx2v_path}/tools/preprocess/preprocess_data.py \
--ckpt_path ${model_path}/process_checkpoint \
--video_path $video_path \
--refer_path $refer_path \
--save_path ${lightx2v_path}/save_results/replace/process_results \
--resolution_area 1280 720 \
--iterations 3 \
--k 7 \
--w_len 1 \
--h_len 1 \
--replace_flag
python -m lightx2v.infer \
--model_cls wan2.2_animate \
--task animate \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan22/wan_animate_replace_4090.json \
--prompt "视频中的人在做动作" \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_replace.mp4
# Wan-animate Preprocessing User Guide
## 1. Introduction
Wan-animate offers two generation modes: `animation` and `replacement`. While both modes extract the skeleton from the driving video, each has its own distinct preprocessing pipeline.
### 1.1 Animation Mode
In this mode, enabling pose retargeting is highly recommended, especially if the body proportions of the reference and driving characters are dissimilar.
- A simplified version of the pose retargeting pipeline is provided to help developers implement this functionality quickly.
- **NOTE:** Due to the potential complexity of input data, the results from this simplified retargeting version are NOT guaranteed to be perfect. It is strongly advised to verify the preprocessing results before proceeding.
- Community contributions to improve this feature are welcome.
### 1.2 Replacement Mode
- Pose retargeting is DISABLED by default in this mode. This is a deliberate choice to account for potential spatial interactions between the character and the environment.
- **WARNING:** If there is a significant mismatch in body proportions between the reference and driving characters, artifacts or deformations may appear in the final output.
- A simplified pipeline for extracting the character's mask is also provided.
- **WARNING:** This mask extraction is designed for **single-person videos ONLY** and may produce incorrect results or fail on multi-person videos (e.g., incorrect pose tracking). For multi-person videos, users will need to develop their own solution or integrate a suitable open-source tool.
---
## 2. Preprocessing Instructions and Recommendations
### 2.1 Basic Usage
- Preprocessing requires some additional models: pose detection (mandatory), plus mask extraction and image editing models (optional, as needed). Place them according to the following directory structure:
```
/path/to/your/ckpt_path/
├── det/
│ └── yolov10m.onnx
├── pose2d/
│ └── vitpose_h_wholebody.onnx
├── sam2/
│ └── sam2_hiera_large.pt
└── FLUX.1-Kontext-dev/
```
- `video_path`, `refer_path`, and `save_path` are the paths to the input driving video, the reference character image, and the preprocessed results, respectively.
- In `animation` mode, two videos, `src_face.mp4` and `src_pose.mp4`, are generated in `save_path`. In `replacement` mode, two additional videos, `src_bg.mp4` and `src_mask.mp4`, are generated as well.
- The `resolution_area` parameter determines the resolution for both preprocessing and the generation model; the target size is matched by total pixel area while preserving the input's aspect ratio.
- The `fps` parameter sets the frame rate for video processing (`-1` keeps the video's original FPS). A lower frame rate improves generation efficiency but may cause stuttering or choppiness. A minimal invocation is shown below.
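For example, a minimal animation-mode run without retargeting might look like the following; every `/path/to/...` value is a placeholder to replace with your own paths:
```
# Minimal animation-mode preprocessing (placeholder paths).
# Writes src_pose.mp4, src_face.mp4 and src_ref.png under --save_path.
python tools/preprocess/preprocess_data.py \
    --ckpt_path /path/to/your/ckpt_path \
    --video_path /path/to/driving_video.mp4 \
    --refer_path /path/to/character.png \
    --save_path ./save_results/animate/process_results \
    --resolution_area 1280 720 \
    --fps 30
```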
---
### 2.2 Animation Mode
- Three variants are supported: no pose retargeting, basic pose retargeting, and enhanced pose retargeting built on the `FLUX.1-Kontext-dev` image editing model. They are selected via the `retarget_flag` and `use_flux` parameters.
- Basic pose retargeting (`retarget_flag` alone) requires that both the reference character and the character in the first frame of the driving video be in a front-facing, outstretched pose.
- Otherwise, we recommend enhanced pose retargeting by specifying both `retarget_flag` and `use_flux`; a sketch follows this list. **NOTE:** Due to the limited capabilities of `FLUX.1-Kontext-dev`, the expected results are NOT guaranteed (e.g., consistency may not be maintained, or the pose may be incorrect). It is recommended to check the intermediate edits as well as the final generated pose video; both are stored in `save_path`. Users can also substitute a better image editing model, or tune the Flux prompts on their own.
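For instance, an enhanced-retargeting run could look like this, assuming `FLUX.1-Kontext-dev` has been placed under the checkpoint directory from Section 2.1 (paths are placeholders):
```
# Enhanced pose retargeting; --use_flux is only valid together with --retarget_flag.
# Intermediate edits (refer_edit.png, tpl_edit.png) are saved for inspection.
python tools/preprocess/preprocess_data.py \
    --ckpt_path /path/to/your/ckpt_path \
    --video_path /path/to/driving_video.mp4 \
    --refer_path /path/to/character.png \
    --save_path ./save_results/animate/process_results \
    --resolution_area 1280 720 \
    --retarget_flag \
    --use_flux
```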
---
### 2.3 Replacement Mode
- Specify `replace_flag` to enable data preprocessing for this mode. Preprocessing additionally extracts a mask for the character in the video; its size and shape can be adjusted with the parameters below.
- `iterations` and `k` enlarge the mask (dilation iterations and kernel size), covering more area.
- `w_len` and `h_len` adjust the mask's shape: smaller values make the outline coarser, while larger values make it finer.
- A smaller, finer-contoured mask preserves more of the original background but may constrain the character's generation area (given potential appearance differences, this can cause some shape leakage). A larger, coarser mask lets character generation be more flexible and consistent, but because it covers more of the background, it may hurt background consistency. We recommend adjusting these parameters to your specific input data; an example invocation is shown below.
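As an illustration, a replacement-mode run that moderately enlarges the mask while keeping a reasonably fine contour might look like this (placeholder paths; the mask values are starting points to tune, not recommendations):
```
# Replacement-mode preprocessing; also writes src_bg.mp4 and src_mask.mp4.
# --iterations/--k grow the mask; --w_len/--h_len refine its contour.
python tools/preprocess/preprocess_data.py \
    --ckpt_path /path/to/your/ckpt_path \
    --video_path /path/to/driving_video.mp4 \
    --refer_path /path/to/character.png \
    --save_path ./save_results/replace/process_results \
    --resolution_area 1280 720 \
    --replace_flag \
    --iterations 3 \
    --k 7 \
    --w_len 2 \
    --h_len 2
```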
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .process_pipepline import ProcessPipeline
from .video_predictor import SAM2VideoPredictor
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
from typing import List, Union
import cv2
import numpy as np
import onnxruntime
import torch
from pose2d_utils import bbox_from_detector, box_convert_simple, crop, keypoints_from_heatmaps, load_pose_metas_from_kp2ds_seq, read_img
class SimpleOnnxInference(object):
def __init__(self, checkpoint, device="cuda", reverse_input=False, **kwargs):
if isinstance(device, str):
device = torch.device(device)
if device.type == "cuda":
device = "{}:{}".format(device.type, device.index)
providers = [("CUDAExecutionProvider", {"device_id": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else "0"}), "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
self.device = device
if not os.path.exists(checkpoint):
raise RuntimeError("{} is not existed!".format(checkpoint))
if os.path.isdir(checkpoint):
checkpoint = os.path.join(checkpoint, "end2end.onnx")
self.session = onnxruntime.InferenceSession(checkpoint, providers=providers)
self.input_name = self.session.get_inputs()[0].name
self.output_name = self.session.get_outputs()[0].name
self.input_resolution = self.session.get_inputs()[0].shape[2:] if not reverse_input else self.session.get_inputs()[0].shape[2:][::-1]
self.input_resolution = np.array(self.input_resolution)
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def get_output_names(self):
output_names = []
for node in self.session.get_outputs():
output_names.append(node.name)
return output_names
def set_device(self, device):
if isinstance(device, str):
device = torch.device(device)
if device.type == "cuda":
device = "{}:{}".format(device.type, device.index)
providers = [("CUDAExecutionProvider", {"device_id": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else "0"}), "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
self.session.set_providers(providers)
self.device = device
class Yolo(SimpleOnnxInference):
def __init__(
self,
checkpoint,
device="cuda",
threshold_conf=0.05,
threshold_multi_persons=0.1,
input_resolution=(640, 640),
threshold_iou=0.5,
threshold_bbox_shape_ratio=0.4,
cat_id=[1],
select_type="max",
strict=True,
sorted_func=None,
**kwargs,
):
super(Yolo, self).__init__(checkpoint, device=device, **kwargs)
model_inputs = self.session.get_inputs()
input_shape = model_inputs[0].shape
self.input_width = 640
self.input_height = 640
self.threshold_multi_persons = threshold_multi_persons
self.threshold_conf = threshold_conf
self.threshold_iou = threshold_iou
self.threshold_bbox_shape_ratio = threshold_bbox_shape_ratio
self.input_resolution = input_resolution
self.cat_id = cat_id
self.select_type = select_type
self.strict = strict
self.sorted_func = sorted_func
def preprocess(self, input_image):
"""
Preprocesses the input image before performing inference.
Returns:
image_data: Preprocessed image data ready for inference.
"""
img = read_img(input_image)
# Get the height and width of the input image
img_height, img_width = img.shape[:2]
# Resize the image to match the input shape
img = cv2.resize(img, (self.input_resolution[1], self.input_resolution[0]))
# Normalize the image data by dividing it by 255.0
image_data = np.array(img) / 255.0
# Transpose the image to have the channel dimension as the first dimension
image_data = np.transpose(image_data, (2, 0, 1)) # Channel first
# Expand the dimensions of the image data to match the expected input shape
# image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
image_data = image_data.astype(np.float32)
# Return the preprocessed image data
return image_data, np.array([img_height, img_width])
def postprocess(self, output, shape_raw, cat_id=[1]):
"""
Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs.
Args:
input_image (numpy.ndarray): The input image.
output (numpy.ndarray): The output of the model.
Returns:
numpy.ndarray: The input image with detections drawn on it.
"""
# Transpose and squeeze the output to match the expected shape
outputs = np.squeeze(output)
if len(outputs.shape) == 1:
outputs = outputs[None]
if output.shape[-1] != 6 and output.shape[1] == 84:
outputs = np.transpose(outputs)
# Get the number of rows in the outputs array
rows = outputs.shape[0]
# Calculate the scaling factors for the bounding box coordinates
x_factor = shape_raw[1] / self.input_width
y_factor = shape_raw[0] / self.input_height
# Lists to store the bounding boxes, scores, and class IDs of the detections
boxes = []
scores = []
class_ids = []
if outputs.shape[-1] == 6:
max_scores = outputs[:, 4]
classid = outputs[:, -1]
threshold_conf_masks = max_scores >= self.threshold_conf
classid_masks = classid[threshold_conf_masks] != 3.14159
max_scores = max_scores[threshold_conf_masks][classid_masks]
classid = classid[threshold_conf_masks][classid_masks]
boxes = outputs[:, :4][threshold_conf_masks][classid_masks]
boxes[:, [0, 2]] *= x_factor
boxes[:, [1, 3]] *= y_factor
boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
boxes = boxes.astype(np.int32)
else:
classes_scores = outputs[:, 4:]
max_scores = np.amax(classes_scores, -1)
threshold_conf_masks = max_scores >= self.threshold_conf
classid = np.argmax(classes_scores[threshold_conf_masks], -1)
classid_masks = classid != 3.14159
classes_scores = classes_scores[threshold_conf_masks][classid_masks]
max_scores = max_scores[threshold_conf_masks][classid_masks]
classid = classid[classid_masks]
xywh = outputs[:, :4][threshold_conf_masks][classid_masks]
x = xywh[:, 0:1]
y = xywh[:, 1:2]
w = xywh[:, 2:3]
h = xywh[:, 3:4]
left = (x - w / 2) * x_factor
top = (y - h / 2) * y_factor
width = w * x_factor
height = h * y_factor
boxes = np.concatenate([left, top, width, height], axis=-1).astype(np.int32)
boxes = boxes.tolist()
scores = max_scores.tolist()
class_ids = classid.tolist()
# Apply non-maximum suppression to filter out overlapping bounding boxes
indices = cv2.dnn.NMSBoxes(boxes, scores, self.threshold_conf, self.threshold_iou)
# Iterate over the selected indices after non-maximum suppression
results = []
for i in indices:
# Get the box, score, and class ID corresponding to the index
box = box_convert_simple(boxes[i], "xywh2xyxy")
score = scores[i]
class_id = class_ids[i]
results.append(box + [score] + [class_id])
# # Draw the detection on the input image
# Return the modified input image
return np.array(results)
def process_results(self, results, shape_raw, cat_id=[1], single_person=True):
if isinstance(results, tuple):
det_results = results[0]
else:
det_results = results
person_results = []
person_count = 0
if len(results):
max_idx = -1
max_bbox_size = shape_raw[0] * shape_raw[1] * -10
max_bbox_shape = -1
bboxes = []
idx_list = []
for i in range(results.shape[0]):
bbox = results[i]
if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf):
idx_list.append(i)
bbox_shape = max((bbox[2] - bbox[0]), (bbox[3] - bbox[1]))
if bbox_shape > max_bbox_shape:
max_bbox_shape = bbox_shape
results = results[idx_list]
for i in range(results.shape[0]):
bbox = results[i]
bboxes.append(bbox)
if self.select_type == "max":
bbox_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
elif self.select_type == "center":
bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1] / 2)) * -1
bbox_shape = max((bbox[2] - bbox[0]), (bbox[3] - bbox[1]))
if bbox_size > max_bbox_size:
if (self.strict or max_idx != -1) and bbox_shape < max_bbox_shape * self.threshold_bbox_shape_ratio:
continue
max_bbox_size = bbox_size
max_bbox_shape = bbox_shape
max_idx = i
if self.sorted_func is not None and len(bboxes) > 0:
max_idx = self.sorted_func(bboxes, shape_raw)
bbox = bboxes[max_idx]
if self.select_type == "max":
max_bbox_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
elif self.select_type == "center":
max_bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1] / 2)) * -1
if max_idx != -1:
person_count = 1
if max_idx != -1:
person = {}
person["bbox"] = results[max_idx, :5]
person["track_id"] = int(0)
person_results.append(person)
for i in range(results.shape[0]):
bbox = results[i]
if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf):
if self.select_type == "max":
bbox_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
elif self.select_type == "center":
bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1] / 2)) * -1
if i != max_idx and bbox_size > max_bbox_size * self.threshold_multi_persons and bbox_size < max_bbox_size:
person_count += 1
if not single_person:
person = {}
person["bbox"] = results[i, :5]
person["track_id"] = int(person_count - 1)
person_results.append(person)
return person_results
else:
return None
def postprocess_threading(self, outputs, shape_raw, person_results, i, single_person=True, **kwargs):
result = self.postprocess(outputs[i], shape_raw[i], cat_id=self.cat_id)
result = self.process_results(result, shape_raw[i], cat_id=self.cat_id, single_person=single_person)
if result is not None and len(result) != 0:
person_results[i] = result
def forward(self, img, shape_raw, **kwargs):
"""
Performs inference using an ONNX model and returns the output image with drawn detections.
Returns:
output_img: The output image with drawn detections.
"""
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
shape_raw = shape_raw.cpu().numpy()
outputs = self.session.run(None, {self.session.get_inputs()[0].name: img})[0]
person_results = [[{"bbox": np.array([0.0, 0.0, 1.0 * shape_raw[i][1], 1.0 * shape_raw[i][0], -1]), "track_id": -1}] for i in range(len(outputs))]
for i in range(len(outputs)):
self.postprocess_threading(outputs, shape_raw, person_results, i, **kwargs)
return person_results
class ViTPose(SimpleOnnxInference):
def __init__(self, checkpoint, device="cuda", **kwargs):
super(ViTPose, self).__init__(checkpoint, device=device)
def forward(self, img, center, scale, **kwargs):
heatmaps = self.session.run([], {self.session.get_inputs()[0].name: img})[0]
points, prob = keypoints_from_heatmaps(heatmaps=heatmaps, center=center, scale=scale * 200, unbiased=True, use_udp=False)
return np.concatenate([points, prob], axis=2)
@staticmethod
def preprocess(img, bbox=None, input_resolution=(256, 192), rescale=1.25, mask=None, **kwargs):
if bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10:
bbox = np.array([0, 0, img.shape[1], img.shape[0]])
bbox_xywh = bbox
if mask is not None:
img = np.where(mask > 128, img, mask)
if isinstance(input_resolution, int):
center, scale = bbox_from_detector(bbox_xywh, (input_resolution, input_resolution), rescale=rescale)
img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution, input_resolution))
else:
center, scale = bbox_from_detector(bbox_xywh, input_resolution, rescale=rescale)
img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution[0], input_resolution[1]))
IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406])
IMG_NORM_STD = np.array([0.229, 0.224, 0.225])
img_norm = (img / 255.0 - IMG_NORM_MEAN) / IMG_NORM_STD
img_norm = img_norm.transpose(2, 0, 1).astype(np.float32)
return img_norm, np.array(center), np.array(scale)
class Pose2d:
def __init__(self, checkpoint, detector_checkpoint=None, device="cuda", **kwargs):
if detector_checkpoint is not None:
self.detector = Yolo(detector_checkpoint, device)
else:
self.detector = None
self.model = ViTPose(checkpoint, device)
self.device = device
def load_images(self, inputs):
"""
Load images from various input types.
Args:
inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path,
single image array, or list of image arrays
Returns:
List[np.ndarray]: List of RGB image arrays
Raises:
ValueError: If file format is unsupported or image cannot be read
"""
if isinstance(inputs, str):
if inputs.lower().endswith((".mp4", ".avi", ".mov", ".mkv")):
cap = cv2.VideoCapture(inputs)
frames = []
while True:
ret, frame = cap.read()
if not ret:
break
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()
images = frames
elif inputs.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
img = cv2.cvtColor(cv2.imread(inputs), cv2.COLOR_BGR2RGB)
if img is None:
raise ValueError(f"Cannot read image: {inputs}")
images = [img]
else:
raise ValueError(f"Unsupported file format: {inputs}")
elif isinstance(inputs, np.ndarray):
images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs]
elif isinstance(inputs, list):
images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs]
return images
def __call__(self, inputs: Union[str, np.ndarray, List[np.ndarray]], return_image: bool = False, **kwargs):
"""
Process input and estimate 2D keypoints.
Args:
inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path,
single image array, or list of image arrays
**kwargs: Additional arguments for processing
Returns:
np.ndarray: Array of detected 2D keypoints for all input images
"""
images = self.load_images(inputs)
H, W = images[0].shape[:2]
if self.detector is not None:
bboxes = []
for _image in images:
img, shape = self.detector.preprocess(_image)
bboxes.append(self.detector(img[None], shape[None])[0][0]["bbox"])
else:
bboxes = [None] * len(images)
kp2ds = []
for _image, _bbox in zip(images, bboxes):
img, center, scale = self.model.preprocess(_image, _bbox)
kp2ds.append(self.model(img[None], center[None], scale[None]))
kp2ds = np.concatenate(kp2ds, 0)
metas = load_pose_metas_from_kp2ds_seq(kp2ds, width=W, height=H)
return metas
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
from process_pipepline import ProcessPipeline
def _parse_args():
parser = argparse.ArgumentParser(description="The preprocessing pipeline for Wan-animate.")
parser.add_argument("--ckpt_path", type=str, default=None, help="The path to the preprocessing model's checkpoint directory. ")
parser.add_argument("--video_path", type=str, default=None, help="The path to the driving video.")
parser.add_argument("--refer_path", type=str, default=None, help="The path to the refererence image.")
parser.add_argument("--save_path", type=str, default=None, help="The path to save the processed results.")
parser.add_argument(
"--resolution_area",
type=int,
nargs=2,
default=[1280, 720],
help="The target resolution for processing, specified as [width, height]. To handle different aspect ratios, the video is resized to have a total area equivalent to width * height, while preserving the original aspect ratio.",
)
parser.add_argument("--fps", type=int, default=30, help="The target FPS for processing the driving video. Set to -1 to use the video's original FPS.")
parser.add_argument("--replace_flag", action="store_true", default=False, help="Whether to use replacement mode.")
parser.add_argument("--retarget_flag", action="store_true", default=False, help="Whether to use pose retargeting. Currently only supported in animation mode")
parser.add_argument(
"--use_flux",
action="store_true",
default=False,
help="Whether to use image editing in pose retargeting. Recommended if the character in the reference image or the first frame of the driving video is not in a standard, front-facing pose",
)
# Parameters for the mask strategy in replacement mode. These control the mask's size and shape. Refer to https://arxiv.org/pdf/2502.06145
parser.add_argument("--iterations", type=int, default=3, help="Number of iterations for mask dilation.")
parser.add_argument("--k", type=int, default=7, help="Number of kernel size for mask dilation.")
parser.add_argument(
"--w_len",
type=int,
default=1,
help="The number of subdivisions for the grid along the 'w' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed.",
)
parser.add_argument(
"--h_len",
type=int,
default=1,
help="The number of subdivisions for the grid along the 'h' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed.",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = _parse_args()
args_dict = vars(args)
print(args_dict)
assert len(args.resolution_area) == 2, "resolution_area should be a list of two integers [width, height]"
assert not args.use_flux or args.retarget_flag, "Image editing with FLUX can only be used when pose retargeting is enabled."
pose2d_checkpoint_path = os.path.join(args.ckpt_path, "pose2d/vitpose_h_wholebody.onnx")
det_checkpoint_path = os.path.join(args.ckpt_path, "det/yolov10m.onnx")
sam2_checkpoint_path = os.path.join(args.ckpt_path, "sam2/sam2_hiera_large.pt") if args.replace_flag else None
flux_kontext_path = os.path.join(args.ckpt_path, "FLUX.1-Kontext-dev") if args.use_flux else None
process_pipeline = ProcessPipeline(
det_checkpoint_path=det_checkpoint_path, pose2d_checkpoint_path=pose2d_checkpoint_path, sam_checkpoint_path=sam2_checkpoint_path, flux_kontext_path=flux_kontext_path
)
os.makedirs(args.save_path, exist_ok=True)
process_pipeline(
video_path=args.video_path,
refer_image_path=args.refer_path,
output_path=args.save_path,
resolution_area=args.resolution_area,
fps=args.fps,
iterations=args.iterations,
k=args.k,
w_len=args.w_len,
h_len=args.h_len,
retarget_flag=args.retarget_flag,
use_flux=args.use_flux,
replace_flag=args.replace_flag,
)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import shutil
import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import FluxKontextPipeline
from loguru import logger
try:
import moviepy.editor as mpy
except: # noqa
import moviepy as mpy
import sam2.modeling.sam.transformer as transformer
from decord import VideoReader
from human_visualization import draw_aapose_by_meta_new
from pose2d import Pose2d
from pose2d_utils import AAPoseMeta
from retarget_pose import get_retarget_pose
from utils import get_aug_mask, get_face_bboxes, get_frame_indices, get_mask_body_img, padding_resize, resize_by_area
transformer.USE_FLASH_ATTN = False
transformer.MATH_KERNEL_ON = True
transformer.OLD_GPU = True
from sam_utils import build_sam2_video_predictor # noqa
class ProcessPipeline:
def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_checkpoint_path, flux_kontext_path):
self.pose2d = Pose2d(checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path)
model_cfg = "sam2_hiera_l.yaml"
if sam_checkpoint_path is not None:
self.predictor = build_sam2_video_predictor(model_cfg, sam_checkpoint_path)
if flux_kontext_path is not None:
self.flux_kontext = FluxKontextPipeline.from_pretrained(flux_kontext_path, torch_dtype=torch.bfloat16).to("cuda")
def __call__(self, video_path, refer_image_path, output_path, resolution_area=[1280, 720], fps=30, iterations=3, k=7, w_len=1, h_len=1, retarget_flag=False, use_flux=False, replace_flag=False):
if replace_flag:
video_reader = VideoReader(video_path)
frame_num = len(video_reader)
print("frame_num: {}".format(frame_num))
video_fps = video_reader.get_avg_fps()
print("video_fps: {}".format(video_fps))
print("fps: {}".format(fps))
# TODO: Maybe we can switch to PyAV later, which can get accurate frame num
duration = video_reader.get_frame_timestamp(-1)[-1]
expected_frame_num = int(duration * video_fps + 0.5)
ratio = abs((frame_num - expected_frame_num) / frame_num)
if ratio > 0.1:
print("Warning: The difference between the actual number of frames and the expected number of frames is two large")
frame_num = expected_frame_num
if fps == -1:
fps = video_fps
target_num = int(frame_num / video_fps * fps)
print("target_num: {}".format(target_num))
idxs = get_frame_indices(frame_num, video_fps, target_num, fps)
frames = video_reader.get_batch(idxs).asnumpy()
frames = [resize_by_area(frame, resolution_area[0] * resolution_area[1], divisor=16) for frame in frames]
height, width = frames[0].shape[:2]
logger.info(f"Processing pose meta")
tpl_pose_metas = self.pose2d(frames)
face_images = []
for idx, meta in enumerate(tpl_pose_metas):
face_bbox_for_image = get_face_bboxes(meta["keypoints_face"][:, :2], scale=1.3, image_shape=(frames[0].shape[0], frames[0].shape[1]))
x1, x2, y1, y2 = face_bbox_for_image
face_image = frames[idx][y1:y2, x1:x2]
face_image = cv2.resize(face_image, (512, 512))
face_images.append(face_image)
logger.info(f"Processing reference image: {refer_image_path}")
refer_img = cv2.imread(refer_image_path)
src_ref_path = os.path.join(output_path, "src_ref.png")
shutil.copy(refer_image_path, src_ref_path)
refer_img = refer_img[..., ::-1]
refer_img = padding_resize(refer_img, height, width)
logger.info(f"Processing template video: {video_path}")
tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]
cond_images = []
for idx, meta in enumerate(tpl_retarget_pose_metas):
canvas = np.zeros_like(refer_img)
conditioning_image = draw_aapose_by_meta_new(canvas, meta)
cond_images.append(conditioning_image)
masks = self.get_mask(frames, 400, tpl_pose_metas)
bg_images = []
aug_masks = []
for frame, mask in zip(frames, masks):
if iterations > 0:
_, each_mask = get_mask_body_img(frame, mask, iterations=iterations, k=k)
each_aug_mask = get_aug_mask(each_mask, w_len=w_len, h_len=h_len)
else:
each_aug_mask = mask
each_bg_image = frame * (1 - each_aug_mask[:, :, None])
bg_images.append(each_bg_image)
aug_masks.append(each_aug_mask)
src_face_path = os.path.join(output_path, "src_face.mp4")
mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path)
src_pose_path = os.path.join(output_path, "src_pose.mp4")
mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path)
src_bg_path = os.path.join(output_path, "src_bg.mp4")
mpy.ImageSequenceClip(bg_images, fps=fps).write_videofile(src_bg_path)
aug_masks_new = [np.stack([mask * 255, mask * 255, mask * 255], axis=2) for mask in aug_masks]
src_mask_path = os.path.join(output_path, "src_mask.mp4")
mpy.ImageSequenceClip(aug_masks_new, fps=fps).write_videofile(src_mask_path)
return True
else:
logger.info(f"Processing reference image: {refer_image_path}")
refer_img = cv2.imread(refer_image_path)
src_ref_path = os.path.join(output_path, "src_ref.png")
shutil.copy(refer_image_path, src_ref_path)
refer_img = refer_img[..., ::-1]
refer_img = resize_by_area(refer_img, resolution_area[0] * resolution_area[1], divisor=16)
refer_pose_meta = self.pose2d([refer_img])[0]
logger.info(f"Processing template video: {video_path}")
video_reader = VideoReader(video_path)
frame_num = len(video_reader)
print("frame_num: {}".format(frame_num))
video_fps = video_reader.get_avg_fps()
print("video_fps: {}".format(video_fps))
print("fps: {}".format(fps))
# TODO: Maybe we can switch to PyAV later, which can get accurate frame num
duration = video_reader.get_frame_timestamp(-1)[-1]
expected_frame_num = int(duration * video_fps + 0.5)
ratio = abs((frame_num - expected_frame_num) / frame_num)
if ratio > 0.1:
print("Warning: The difference between the actual number of frames and the expected number of frames is two large")
frame_num = expected_frame_num
if fps == -1:
fps = video_fps
target_num = int(frame_num / video_fps * fps)
print("target_num: {}".format(target_num))
idxs = get_frame_indices(frame_num, video_fps, target_num, fps)
frames = video_reader.get_batch(idxs).asnumpy()
logger.info(f"Processing pose meta")
tpl_pose_meta0 = self.pose2d(frames[:1])[0]
tpl_pose_metas = self.pose2d(frames)
face_images = []
for idx, meta in enumerate(tpl_pose_metas):
face_bbox_for_image = get_face_bboxes(meta["keypoints_face"][:, :2], scale=1.3, image_shape=(frames[0].shape[0], frames[0].shape[1]))
x1, x2, y1, y2 = face_bbox_for_image
face_image = frames[idx][y1:y2, x1:x2]
face_image = cv2.resize(face_image, (512, 512))
face_images.append(face_image)
if retarget_flag:
if use_flux:
tpl_prompt, refer_prompt = self.get_editing_prompts(tpl_pose_metas, refer_pose_meta)
refer_input = Image.fromarray(refer_img)
refer_edit = self.flux_kontext(
image=refer_input,
height=refer_img.shape[0],
width=refer_img.shape[1],
prompt=refer_prompt,
guidance_scale=2.5,
num_inference_steps=28,
).images[0]
refer_edit = Image.fromarray(padding_resize(np.array(refer_edit), refer_img.shape[0], refer_img.shape[1]))
refer_edit_path = os.path.join(output_path, "refer_edit.png")
refer_edit.save(refer_edit_path)
refer_edit_pose_meta = self.pose2d([np.array(refer_edit)])[0]
tpl_img = frames[1]
tpl_input = Image.fromarray(tpl_img)
tpl_edit = self.flux_kontext(
image=tpl_input,
height=tpl_img.shape[0],
width=tpl_img.shape[1],
prompt=tpl_prompt,
guidance_scale=2.5,
num_inference_steps=28,
).images[0]
tpl_edit = Image.fromarray(padding_resize(np.array(tpl_edit), tpl_img.shape[0], tpl_img.shape[1]))
tpl_edit_path = os.path.join(output_path, "tpl_edit.png")
tpl_edit.save(tpl_edit_path)
tpl_edit_pose_meta0 = self.pose2d([np.array(tpl_edit)])[0]
tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tpl_edit_pose_meta0, refer_edit_pose_meta)
else:
tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, None, None)
else:
tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]
cond_images = []
for idx, meta in enumerate(tpl_retarget_pose_metas):
if retarget_flag:
canvas = np.zeros_like(refer_img)
conditioning_image = draw_aapose_by_meta_new(canvas, meta)
else:
canvas = np.zeros_like(frames[0])
conditioning_image = draw_aapose_by_meta_new(canvas, meta)
conditioning_image = padding_resize(conditioning_image, refer_img.shape[0], refer_img.shape[1])
cond_images.append(conditioning_image)
src_face_path = os.path.join(output_path, "src_face.mp4")
mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path)
src_pose_path = os.path.join(output_path, "src_pose.mp4")
mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path)
return True
def get_editing_prompts(self, tpl_pose_metas, refer_pose_meta):
arm_visible = False
leg_visible = False
for tpl_pose_meta in tpl_pose_metas:
tpl_keypoints = tpl_pose_meta["keypoints_body"]
if tpl_keypoints[3].all() != 0 or tpl_keypoints[4].all() != 0 or tpl_keypoints[6].all() != 0 or tpl_keypoints[7].all() != 0:
if (
(tpl_keypoints[3][0] <= 1 and tpl_keypoints[3][1] <= 1 and tpl_keypoints[3][2] >= 0.75)
or (tpl_keypoints[4][0] <= 1 and tpl_keypoints[4][1] <= 1 and tpl_keypoints[4][2] >= 0.75)
or (tpl_keypoints[6][0] <= 1 and tpl_keypoints[6][1] <= 1 and tpl_keypoints[6][2] >= 0.75)
or (tpl_keypoints[7][0] <= 1 and tpl_keypoints[7][1] <= 1 and tpl_keypoints[7][2] >= 0.75)
):
arm_visible = True
if tpl_keypoints[9].all() != 0 or tpl_keypoints[12].all() != 0 or tpl_keypoints[10].all() != 0 or tpl_keypoints[13].all() != 0:
if (
(tpl_keypoints[9][0] <= 1 and tpl_keypoints[9][1] <= 1 and tpl_keypoints[9][2] >= 0.75)
or (tpl_keypoints[12][0] <= 1 and tpl_keypoints[12][1] <= 1 and tpl_keypoints[12][2] >= 0.75)
or (tpl_keypoints[10][0] <= 1 and tpl_keypoints[10][1] <= 1 and tpl_keypoints[10][2] >= 0.75)
or (tpl_keypoints[13][0] <= 1 and tpl_keypoints[13][1] <= 1 and tpl_keypoints[13][2] >= 0.75)
):
leg_visible = True
if arm_visible and leg_visible:
break
if leg_visible:
if tpl_pose_meta["width"] > tpl_pose_meta["height"]:
tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image."
else:
tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image."
if refer_pose_meta["width"] > refer_pose_meta["height"]:
refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image."
else:
refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image."
elif arm_visible:
if tpl_pose_meta["width"] > tpl_pose_meta["height"]:
tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image."
else:
tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image."
if refer_pose_meta["width"] > refer_pose_meta["height"]:
refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image."
else:
refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image."
else:
tpl_prompt = "Change the person to face forward."
refer_prompt = "Change the person to face forward."
return tpl_prompt, refer_prompt
def get_mask(self, frames, th_step, kp2ds_all):
frame_num = len(frames)
if frame_num < th_step:
num_step = 1
else:
num_step = (frame_num + th_step) // th_step
all_mask = []
for index in range(num_step):
each_frames = frames[index * th_step : (index + 1) * th_step]
kp2ds = kp2ds_all[index * th_step : (index + 1) * th_step]
if len(each_frames) > 4:
key_frame_num = 4
elif 4 >= len(each_frames) > 0:
key_frame_num = 1
else:
continue
key_frame_step = len(kp2ds) // key_frame_num
key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))
key_points_index = [0, 1, 2, 5, 8, 11, 10, 13]
key_frame_body_points_list = []
for key_frame_index in key_frame_index_list:
keypoints_body_list = []
body_key_points = kp2ds[key_frame_index]["keypoints_body"]
for each_index in key_points_index:
each_keypoint = body_key_points[each_index]
                    if each_keypoint is None:
continue
keypoints_body_list.append(each_keypoint)
keypoints_body = np.array(keypoints_body_list)[:, :2]
wh = np.array([[kp2ds[0]["width"], kp2ds[0]["height"]]])
points = (keypoints_body * wh).astype(np.int32)
key_frame_body_points_list.append(points)
inference_state = self.predictor.init_state_v2(frames=each_frames)
self.predictor.reset_state(inference_state)
ann_obj_id = 1
for ann_frame_idx, points in zip(key_frame_index_list, key_frame_body_points_list):
labels = np.array([1] * points.shape[0], np.int32)
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=points,
labels=labels,
)
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids)}
for out_frame_idx in range(len(video_segments)):
for out_obj_id, out_mask in video_segments[out_frame_idx].items():
out_mask = out_mask[0].astype(np.uint8)
all_mask.append(out_mask)
return all_mask
def convert_list_to_array(self, metas):
metas_list = []
for meta in metas:
for key, value in meta.items():
if type(value) is list:
value = np.array(value)
meta[key] = value
metas_list.append(meta)
return metas_list
# Copyright (c) 2025. Your modifications here.
# This file wraps and extends sam2.utils.misc for custom modifications.
import os
import numpy as np
import torch
from PIL import Image
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from sam2.build_sam import _load_checkpoint
from sam2.utils.misc import *
from sam2.utils.misc import AsyncVideoFrameLoader, _load_img_as_tensor
from tqdm import tqdm
def _load_img_v2_as_tensor(img, image_size):
img_pil = Image.fromarray(img.astype(np.uint8))
img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
img_np = img_np / 255.0
else:
raise RuntimeError(f"Unknown image dtype: {img_np.dtype}")
img = torch.from_numpy(img_np).permute(2, 0, 1)
video_width, video_height = img_pil.size # the original video size
return img, video_height, video_width
def load_video_frames(
video_path,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
frame_names=None,
):
"""
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
The frames are resized to image_size x image_size and are loaded to GPU if
`offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
You can load a frame asynchronously by setting `async_loading_frames` to `True`.
"""
if isinstance(video_path, str) and os.path.isdir(video_path):
jpg_folder = video_path
else:
raise NotImplementedError("Only JPEG frames are supported at this moment")
if frame_names is None:
frame_names = [p for p in os.listdir(jpg_folder) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"]]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
num_frames = len(frame_names)
if num_frames == 0:
raise RuntimeError(f"no images found in {jpg_folder}")
img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
if async_loading_frames:
lazy_images = AsyncVideoFrameLoader(img_paths, image_size, offload_video_to_cpu, img_mean, img_std)
return lazy_images, lazy_images.video_height, lazy_images.video_width
images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
images = images.cuda()
img_mean = img_mean.cuda()
img_std = img_std.cuda()
# normalize by mean and std
images -= img_mean
images /= img_std
return images, video_height, video_width
def load_video_frames_v2(
frames,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
frame_names=None,
):
"""
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
The frames are resized to image_size x image_size and are loaded to GPU if
`offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
You can load a frame asynchronously by setting `async_loading_frames` to `True`.
"""
num_frames = len(frames)
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, frame in enumerate(tqdm(frames, desc="video frame")):
images[n], video_height, video_width = _load_img_v2_as_tensor(frame, image_size)
if not offload_video_to_cpu:
images = images.cuda()
img_mean = img_mean.cuda()
img_std = img_std.cuda()
# normalize by mean and std
images -= img_mean
images /= img_std
return images, video_height, video_width
def build_sam2_video_predictor(
config_file,
ckpt_path=None,
device="cuda",
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
):
hydra_overrides = [
"++model._target_=video_predictor.SAM2VideoPredictor",
]
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
# the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
"++model.binarize_mask_from_pts_for_mem_enc=true",
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
"++model.fill_hole_area=8",
]
hydra_overrides.extend(hydra_overrides_extra)
# Read config and init model
cfg = compose(config_name=config_file, overrides=hydra_overrides)
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)
_load_checkpoint(model, ckpt_path)
model = model.to(device)
if mode == "eval":
model.eval()
return model
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import random
import cv2
import numpy as np
def get_mask_boxes(mask):
"""
Args:
mask: [h, w]
Returns:
"""
y_coords, x_coords = np.nonzero(mask)
x_min = x_coords.min()
x_max = x_coords.max()
y_min = y_coords.min()
y_max = y_coords.max()
bbox = np.array([x_min, y_min, x_max, y_max]).astype(np.int32)
return bbox
def get_aug_mask(body_mask, w_len=10, h_len=20):
body_bbox = get_mask_boxes(body_mask)
bbox_wh = body_bbox[2:4] - body_bbox[0:2]
w_slice = np.int32(bbox_wh[0] / w_len)
h_slice = np.int32(bbox_wh[1] / h_len)
for each_w in range(body_bbox[0], body_bbox[2], w_slice):
w_start = min(each_w, body_bbox[2])
w_end = min((each_w + w_slice), body_bbox[2])
# print(w_start, w_end)
for each_h in range(body_bbox[1], body_bbox[3], h_slice):
h_start = min(each_h, body_bbox[3])
h_end = min((each_h + h_slice), body_bbox[3])
if body_mask[h_start:h_end, w_start:w_end].sum() > 0:
body_mask[h_start:h_end, w_start:w_end] = 1
return body_mask
def get_mask_body_img(img_copy, hand_mask, k=7, iterations=1):
kernel = np.ones((k, k), np.uint8)
dilation = cv2.dilate(hand_mask, kernel, iterations=iterations)
mask_hand_img = img_copy * (1 - dilation[:, :, None])
return mask_hand_img, dilation
def get_face_bboxes(kp2ds, scale, image_shape, ratio_aug):
h, w = image_shape
kp2ds_face = kp2ds.copy()[23:91, :2]
min_x, min_y = np.min(kp2ds_face, axis=0)
max_x, max_y = np.max(kp2ds_face, axis=0)
initial_width = max_x - min_x
initial_height = max_y - min_y
initial_area = initial_width * initial_height
expanded_area = initial_area * scale
new_width = np.sqrt(expanded_area * (initial_width / initial_height))
new_height = np.sqrt(expanded_area * (initial_height / initial_width))
delta_width = (new_width - initial_width) / 2
delta_height = (new_height - initial_height) / 4
if ratio_aug:
if random.random() > 0.5:
delta_width += random.uniform(0, initial_width // 10)
else:
delta_height += random.uniform(0, initial_height // 10)
expanded_min_x = max(min_x - delta_width, 0)
expanded_max_x = min(max_x + delta_width, w)
expanded_min_y = max(min_y - 3 * delta_height, 0)
expanded_max_y = min(max_y + delta_height, h)
return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)]
def calculate_new_size(orig_w, orig_h, target_area, divisor=64):
target_ratio = orig_w / orig_h
def check_valid(w, h):
if w <= 0 or h <= 0:
return False
return w * h <= target_area and w % divisor == 0 and h % divisor == 0
def get_ratio_diff(w, h):
return abs(w / h - target_ratio)
def round_to_64(value, round_up=False, divisor=64):
if round_up:
return divisor * ((value + (divisor - 1)) // divisor)
return divisor * (value // divisor)
possible_sizes = []
max_area_h = int(np.sqrt(target_area / target_ratio))
max_area_w = int(max_area_h * target_ratio)
max_h = round_to_64(max_area_h, round_up=True, divisor=divisor)
max_w = round_to_64(max_area_w, round_up=True, divisor=divisor)
for h in range(divisor, max_h + divisor, divisor):
ideal_w = h * target_ratio
        w_down = round_to_64(ideal_w, divisor=divisor)
        w_up = round_to_64(ideal_w, round_up=True, divisor=divisor)
        for w in [w_down, w_up]:
            if check_valid(w, h):  # check_valid takes (w, h); divisor is read from the enclosing scope
possible_sizes.append((w, h, get_ratio_diff(w, h)))
if not possible_sizes:
raise ValueError("Can not find suitable size")
possible_sizes.sort(key=lambda x: (-x[0] * x[1], x[2]))
best_w, best_h, _ = possible_sizes[0]
return int(best_w), int(best_h)
def resize_by_area(image, target_area, keep_aspect_ratio=True, divisor=64, padding_color=(0, 0, 0)):
h, w = image.shape[:2]
try:
new_w, new_h = calculate_new_size(w, h, target_area, divisor)
except: # noqa
aspect_ratio = w / h
if keep_aspect_ratio:
new_h = math.sqrt(target_area / aspect_ratio)
new_w = target_area / new_h
else:
new_w = new_h = math.sqrt(target_area)
new_w, new_h = int((new_w // divisor) * divisor), int((new_h // divisor) * divisor)
interpolation = cv2.INTER_AREA if (new_w * new_h < w * h) else cv2.INTER_LINEAR
resized_image = padding_resize(image, height=new_h, width=new_w, padding_color=padding_color, interpolation=interpolation)
return resized_image
def padding_resize(img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):
ori_height = img_ori.shape[0]
ori_width = img_ori.shape[1]
channel = img_ori.shape[2]
img_pad = np.zeros((height, width, channel))
if channel == 1:
img_pad[:, :, 0] = padding_color[0]
else:
img_pad[:, :, 0] = padding_color[0]
img_pad[:, :, 1] = padding_color[1]
img_pad[:, :, 2] = padding_color[2]
if (ori_height / ori_width) > (height / width):
new_width = int(height / ori_height * ori_width)
img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)
padding = int((width - new_width) / 2)
if len(img.shape) == 2:
img = img[:, :, np.newaxis]
img_pad[:, padding : padding + new_width, :] = img
else:
new_height = int(width / ori_width * ori_height)
img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)
padding = int((height - new_height) / 2)
if len(img.shape) == 2:
img = img[:, :, np.newaxis]
img_pad[padding : padding + new_height, :, :] = img
img_pad = np.uint8(img_pad)
return img_pad
def get_frame_indices(frame_num, video_fps, clip_length, train_fps):
start_frame = 0
times = np.arange(0, clip_length) / train_fps
frame_indices = start_frame + np.round(times * video_fps).astype(int)
frame_indices = np.clip(frame_indices, 0, frame_num - 1)
return frame_indices.tolist()
def get_face_bboxes(kp2ds, scale, image_shape):
h, w = image_shape
kp2ds_face = kp2ds.copy()[1:] * (w, h)
min_x, min_y = np.min(kp2ds_face, axis=0)
max_x, max_y = np.max(kp2ds_face, axis=0)
initial_width = max_x - min_x
initial_height = max_y - min_y
initial_area = initial_width * initial_height
expanded_area = initial_area * scale
new_width = np.sqrt(expanded_area * (initial_width / initial_height))
new_height = np.sqrt(expanded_area * (initial_height / initial_width))
delta_width = (new_width - initial_width) / 2
delta_height = (new_height - initial_height) / 4
expanded_min_x = max(min_x - delta_width, 0)
expanded_max_x = min(max_x + delta_width, w)
expanded_min_y = max(min_y - 3 * delta_height, 0)
expanded_max_y = min(max_y + delta_height, h)
return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)]
# Copyright (c) 2025. Your modifications here.
# A wrapper for sam2 functions
from collections import OrderedDict
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor as _SAM2VideoPredictor
from sam_utils import load_video_frames, load_video_frames_v2
class SAM2VideoPredictor(_SAM2VideoPredictor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@torch.inference_mode()
def init_state(self, video_path, offload_video_to_cpu=False, offload_state_to_cpu=False, async_loading_frames=False, frame_names=None):
"""Initialize a inference state."""
images, video_height, video_width = load_video_frames(
video_path=video_path, image_size=self.image_size, offload_video_to_cpu=offload_video_to_cpu, async_loading_frames=async_loading_frames, frame_names=frame_names
)
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
# whether to offload the inference state to CPU memory
# turning on this option saves the GPU memory at the cost of a lower tracking fps
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
# and from 24 to 21 when tracking two objects)
inference_state["offload_state_to_cpu"] = offload_state_to_cpu
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = torch.device("cuda")
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = torch.device("cuda")
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
# visual features on a small number of recently visited frames for quick interactions
inference_state["cached_features"] = {}
# values that don't change across frames (so we only need to hold one copy of them)
inference_state["constants"] = {}
# mapping between client-side object id and model-side object index
inference_state["obj_id_to_idx"] = OrderedDict()
inference_state["obj_idx_to_id"] = OrderedDict()
inference_state["obj_ids"] = []
# A storage to hold the model's tracking results and states on each frame
inference_state["output_dict"] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
inference_state["output_dict_per_obj"] = {}
# A temporary storage to hold new outputs when user interact with a frame
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
inference_state["temp_output_dict_per_obj"] = {}
# Frames that already holds consolidated outputs from click or mask inputs
# (we directly use their consolidated outputs during tracking)
inference_state["consolidated_frame_inds"] = {
"cond_frame_outputs": set(), # set containing frame indices
"non_cond_frame_outputs": set(), # set containing frame indices
}
# metadata for each tracking frame (e.g. which direction it's tracked)
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"] = {}
# Warm up the visual backbone and cache the image feature on frame 0
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state
@torch.inference_mode()
def init_state_v2(self, frames, offload_video_to_cpu=False, offload_state_to_cpu=False, async_loading_frames=False, frame_names=None):
"""Initialize a inference state."""
images, video_height, video_width = load_video_frames_v2(
frames=frames, image_size=self.image_size, offload_video_to_cpu=offload_video_to_cpu, async_loading_frames=async_loading_frames, frame_names=frame_names
)
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
# whether to offload the inference state to CPU memory
# turning on this option saves the GPU memory at the cost of a lower tracking fps
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
# and from 24 to 21 when tracking two objects)
inference_state["offload_state_to_cpu"] = offload_state_to_cpu
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = torch.device("cuda")
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = torch.device("cuda")
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
# visual features on a small number of recently visited frames for quick interactions
inference_state["cached_features"] = {}
# values that don't change across frames (so we only need to hold one copy of them)
inference_state["constants"] = {}
# mapping between client-side object id and model-side object index
inference_state["obj_id_to_idx"] = OrderedDict()
inference_state["obj_idx_to_id"] = OrderedDict()
inference_state["obj_ids"] = []
# A storage to hold the model's tracking results and states on each frame
inference_state["output_dict"] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
inference_state["output_dict_per_obj"] = {}
# A temporary storage to hold new outputs when user interact with a frame
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
inference_state["temp_output_dict_per_obj"] = {}
# Frames that already holds consolidated outputs from click or mask inputs
# (we directly use their consolidated outputs during tracking)
inference_state["consolidated_frame_inds"] = {
"cond_frame_outputs": set(), # set containing frame indices
"non_cond_frame_outputs": set(), # set containing frame indices
}
# metadata for each tracking frame (e.g. which direction it's tracked)
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"] = {}
inference_state["frames_tracked_per_obj"] = {}
# Warm up the visual backbone and cache the image feature on frame 0
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state