Commit 3af09475 authored by luopl's avatar luopl
Browse files

"Initial commit"

parents
Pipeline #3140 canceled with stages
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from collections import OrderedDict
import torch
from tqdm import tqdm
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
class SAM2VideoPredictor(SAM2Base):
"""The predictor class to handle user interactions and manage inference states."""
def __init__(
self,
fill_hole_area=0,
# whether to apply non-overlapping constraints on the output object masks
non_overlap_masks=False,
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
clear_non_cond_mem_around_input=False,
# whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
clear_non_cond_mem_for_multi_obj=False,
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
add_all_frames_to_correct_as_cond=False,
**kwargs,
):
super().__init__(**kwargs)
self.fill_hole_area = fill_hole_area
self.non_overlap_masks = non_overlap_masks
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
@torch.inference_mode()
def init_state(
self,
video_path,
offload_video_to_cpu=False,
offload_state_to_cpu=False,
async_loading_frames=False,
):
"""Initialize an inference state."""
compute_device = self.device # device of the model
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
# whether to offload the inference state to CPU memory
# turning on this option saves the GPU memory at the cost of a lower tracking fps
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
# and from 24 to 21 when tracking two objects)
inference_state["offload_state_to_cpu"] = offload_state_to_cpu
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = compute_device
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = compute_device
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
# visual features on a small number of recently visited frames for quick interactions
inference_state["cached_features"] = {}
# values that don't change across frames (so we only need to hold one copy of them)
inference_state["constants"] = {}
# mapping between client-side object id and model-side object index
inference_state["obj_id_to_idx"] = OrderedDict()
inference_state["obj_idx_to_id"] = OrderedDict()
inference_state["obj_ids"] = []
# A storage to hold the model's tracking results and states on each frame
inference_state["output_dict"] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
inference_state["output_dict_per_obj"] = {}
# A temporary storage to hold new outputs when user interact with a frame
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
inference_state["temp_output_dict_per_obj"] = {}
# Frames that already holds consolidated outputs from click or mask inputs
# (we directly use their consolidated outputs during tracking)
inference_state["consolidated_frame_inds"] = {
"cond_frame_outputs": set(), # set containing frame indices
"non_cond_frame_outputs": set(), # set containing frame indices
}
# metadata for each tracking frame (e.g. which direction it's tracked)
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"] = {}
# Warm up the visual backbone and cache the image feature on frame 0
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state
@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
"""
Load a pretrained model from the Hugging Face hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2VideoPredictor): The loaded model.
"""
from sam2.build_sam import build_sam2_video_predictor_hf
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
return sam_model
def _obj_id_to_idx(self, inference_state, obj_id):
"""Map client-side object id to model-side object index."""
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
if obj_idx is not None:
return obj_idx
# This is a new object id not sent to the server before. We only allow adding
# new objects *before* the tracking starts.
allow_new_object = not inference_state["tracking_has_started"]
if allow_new_object:
# get the next object slot
obj_idx = len(inference_state["obj_id_to_idx"])
inference_state["obj_id_to_idx"][obj_id] = obj_idx
inference_state["obj_idx_to_id"][obj_idx] = obj_id
inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
# set up input and output structures for this object
inference_state["point_inputs_per_obj"][obj_idx] = {}
inference_state["mask_inputs_per_obj"][obj_idx] = {}
inference_state["output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
inference_state["temp_output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
return obj_idx
else:
raise RuntimeError(
f"Cannot add new object id {obj_id} after tracking starts. "
f"All existing object ids: {inference_state['obj_ids']}. "
f"Please call 'reset_state' to restart from scratch."
)
def _obj_idx_to_id(self, inference_state, obj_idx):
"""Map model-side object index to client-side object id."""
return inference_state["obj_idx_to_id"][obj_idx]
def _get_obj_num(self, inference_state):
"""Get the total number of unique object ids received so far in this session."""
return len(inference_state["obj_idx_to_id"])
@torch.inference_mode()
def add_new_points_or_box(
self,
inference_state,
frame_idx,
obj_id,
points=None,
labels=None,
clear_old_points=True,
normalize_coords=True,
box=None,
):
"""Add new points to a frame."""
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
if (points is not None) != (labels is not None):
raise ValueError("points and labels must be provided together")
if points is None and box is None:
raise ValueError("at least one of points or box must be provided as input")
if points is None:
points = torch.zeros(0, 2, dtype=torch.float32)
elif not isinstance(points, torch.Tensor):
points = torch.tensor(points, dtype=torch.float32)
if labels is None:
labels = torch.zeros(0, dtype=torch.int32)
elif not isinstance(labels, torch.Tensor):
labels = torch.tensor(labels, dtype=torch.int32)
if points.dim() == 2:
points = points.unsqueeze(0) # add batch dimension
if labels.dim() == 1:
labels = labels.unsqueeze(0) # add batch dimension
# If `box` is provided, we add it as the first two points with labels 2 and 3
# along with the user-provided points (consistent with how SAM 2 is trained).
if box is not None:
if not clear_old_points:
raise ValueError(
"cannot add box without clearing old points, since "
"box prompt must be provided before any point prompt "
"(please use clear_old_points=True instead)"
)
if inference_state["tracking_has_started"]:
warnings.warn(
"You are adding a box after tracking starts. SAM 2 may not always be "
"able to incorporate a box prompt for *refinement*. If you intend to "
"use box prompt as an *initial* input before tracking, please call "
"'reset_state' on the inference state to restart from scratch.",
category=UserWarning,
stacklevel=2,
)
if not isinstance(box, torch.Tensor):
box = torch.tensor(box, dtype=torch.float32, device=points.device)
box_coords = box.reshape(1, 2, 2)
box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
box_labels = box_labels.reshape(1, 2)
points = torch.cat([box_coords, points], dim=1)
labels = torch.cat([box_labels, labels], dim=1)
if normalize_coords:
video_H = inference_state["video_height"]
video_W = inference_state["video_width"]
points = points / torch.tensor([video_W, video_H]).to(points.device)
# scale the (normalized) coordinates by the model's internal image size
points = points * self.image_size
points = points.to(inference_state["device"])
labels = labels.to(inference_state["device"])
if not clear_old_points:
point_inputs = point_inputs_per_frame.get(frame_idx, None)
else:
point_inputs = None
point_inputs = concat_points(point_inputs, points, labels)
point_inputs_per_frame[frame_idx] = point_inputs
mask_inputs_per_frame.pop(frame_idx, None)
# If this frame hasn't been tracked before, we treat it as an initial conditioning
# frame, meaning that the inputs points are to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input points will be used to correct the already tracked masks.
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
# whether to track in reverse time order
if is_init_cond_frame:
reverse = False
else:
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or
# if the model sees all frames receiving clicks/mask as conditioning frames.
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Get any previously predicted mask logits on this object and feed it along with
# the new clicks into the SAM mask decoder.
prev_sam_mask_logits = None
# lookup temporary output dict first, which contains the most recent output
# (if not found, then lookup conditioning and non-conditioning frame output)
prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
if prev_out is None:
prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
if prev_out is None:
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
if prev_out is not None and prev_out["pred_masks"] is not None:
device = inference_state["device"]
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, _ = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=obj_output_dict, # run on the slice of a single object
frame_idx=frame_idx,
batch_size=1, # run on the slice of a single object
is_init_cond_frame=is_init_cond_frame,
point_inputs=point_inputs,
mask_inputs=None,
reverse=reverse,
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
# at the beginning of `propagate_in_video` (after user finalize their clicks). This
# allows us to enforce non-overlapping constraints on all objects before encoding
# them into memory.
run_mem_encoder=False,
prev_sam_mask_logits=prev_sam_mask_logits,
)
# Add the output to the output dict (to be used as future memory)
obj_temp_output_dict[storage_key][frame_idx] = current_out
# Resize the output mask to the original video resolution
obj_ids = inference_state["obj_ids"]
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
return frame_idx, obj_ids, video_res_masks
def add_new_points(self, *args, **kwargs):
"""Deprecated method. Please use `add_new_points_or_box` instead."""
return self.add_new_points_or_box(*args, **kwargs)
@torch.inference_mode()
def add_new_mask(
self,
inference_state,
frame_idx,
obj_id,
mask,
):
"""Add new mask to a frame."""
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
if not isinstance(mask, torch.Tensor):
mask = torch.tensor(mask, dtype=torch.bool)
assert mask.dim() == 2
mask_H, mask_W = mask.shape
mask_inputs_orig = mask[None, None] # add batch and channel dimension
mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
# resize the mask if it doesn't match the model's image size
if mask_H != self.image_size or mask_W != self.image_size:
mask_inputs = torch.nn.functional.interpolate(
mask_inputs_orig,
size=(self.image_size, self.image_size),
align_corners=False,
mode="bilinear",
antialias=True, # use antialias for downsampling
)
mask_inputs = (mask_inputs >= 0.5).float()
else:
mask_inputs = mask_inputs_orig
mask_inputs_per_frame[frame_idx] = mask_inputs
point_inputs_per_frame.pop(frame_idx, None)
# If this frame hasn't been tracked before, we treat it as an initial conditioning
# frame, meaning that the inputs points are to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input points will be used to correct the already tracked masks.
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
# whether to track in reverse time order
if is_init_cond_frame:
reverse = False
else:
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or
# if the model sees all frames receiving clicks/mask as conditioning frames.
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
current_out, _ = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=obj_output_dict, # run on the slice of a single object
frame_idx=frame_idx,
batch_size=1, # run on the slice of a single object
is_init_cond_frame=is_init_cond_frame,
point_inputs=None,
mask_inputs=mask_inputs,
reverse=reverse,
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
# at the beginning of `propagate_in_video` (after user finalize their clicks). This
# allows us to enforce non-overlapping constraints on all objects before encoding
# them into memory.
run_mem_encoder=False,
)
# Add the output to the output dict (to be used as future memory)
obj_temp_output_dict[storage_key][frame_idx] = current_out
# Resize the output mask to the original video resolution
obj_ids = inference_state["obj_ids"]
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
return frame_idx, obj_ids, video_res_masks
def _get_orig_video_res_output(self, inference_state, any_res_masks):
"""
Resize the object scores to the original video resolution (video_res_masks)
and apply non-overlapping constraints for final output.
"""
device = inference_state["device"]
video_H = inference_state["video_height"]
video_W = inference_state["video_width"]
any_res_masks = any_res_masks.to(device, non_blocking=True)
if any_res_masks.shape[-2:] == (video_H, video_W):
video_res_masks = any_res_masks
else:
video_res_masks = torch.nn.functional.interpolate(
any_res_masks,
size=(video_H, video_W),
mode="bilinear",
align_corners=False,
)
if self.non_overlap_masks:
video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
return any_res_masks, video_res_masks
def _consolidate_temp_output_across_obj(
self,
inference_state,
frame_idx,
is_cond,
run_mem_encoder,
consolidate_at_video_res=False,
):
"""
Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
a frame into a single output for all objects, including
1) fill any missing objects either from `output_dict_per_obj` (if they exist in
`output_dict_per_obj` for this frame) or leave them as placeholder values
(if they don't exist in `output_dict_per_obj` for this frame);
2) if specified, rerun memory encoder after apply non-overlapping constraints
on the object scores.
"""
batch_size = self._get_obj_num(inference_state)
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Optionally, we allow consolidating the temporary outputs at the original
# video resolution (to provide a better editing experience for mask prompts).
if consolidate_at_video_res:
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
consolidated_H = inference_state["video_height"]
consolidated_W = inference_state["video_width"]
consolidated_mask_key = "pred_masks_video_res"
else:
consolidated_H = consolidated_W = self.image_size // 4
consolidated_mask_key = "pred_masks"
# Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
# will be added when rerunning the memory encoder after applying non-overlapping
# constraints to object scores. Its "pred_masks" are prefilled with a large
# negative value (NO_OBJ_SCORE) to represent missing objects.
consolidated_out = {
"maskmem_features": None,
"maskmem_pos_enc": None,
consolidated_mask_key: torch.full(
size=(batch_size, 1, consolidated_H, consolidated_W),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["storage_device"],
),
"obj_ptr": torch.full(
size=(batch_size, self.hidden_dim),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["device"],
),
"object_score_logits": torch.full(
size=(batch_size, 1),
# default to 10.0 for object_score_logits, i.e. assuming the object is
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
fill_value=10.0,
dtype=torch.float32,
device=inference_state["device"],
),
}
empty_mask_ptr = None
for obj_idx in range(batch_size):
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
out = obj_temp_output_dict[storage_key].get(frame_idx, None)
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
# we fall back and look up its previous output in "output_dict_per_obj".
# We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
# "output_dict_per_obj" to find a previous output for this object.
if out is None:
out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
if out is None:
out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
# If the object doesn't appear in "output_dict_per_obj" either, we skip it
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
# placeholder above) and set its object pointer to be a dummy pointer.
if out is None:
# Fill in dummy object pointers for those objects without any inputs or
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
# i.e. when we need to build the memory for tracking).
if run_mem_encoder:
if empty_mask_ptr is None:
empty_mask_ptr = self._get_empty_mask_ptr(
inference_state, frame_idx
)
# fill object pointer with a dummy pointer (based on an empty mask)
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
continue
# Add the temporary object output mask to consolidated output mask
obj_mask = out["pred_masks"]
consolidated_pred_masks = consolidated_out[consolidated_mask_key]
if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
else:
# Resize first if temporary object mask has a different resolution
resized_obj_mask = torch.nn.functional.interpolate(
obj_mask,
size=consolidated_pred_masks.shape[-2:],
mode="bilinear",
align_corners=False,
)
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
"object_score_logits"
]
# Optionally, apply non-overlapping constraints on the consolidated scores
# and rerun the memory encoder
if run_mem_encoder:
device = inference_state["device"]
high_res_masks = torch.nn.functional.interpolate(
consolidated_out["pred_masks"].to(device, non_blocking=True),
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
if self.non_overlap_masks_for_mem_enc:
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
inference_state=inference_state,
frame_idx=frame_idx,
batch_size=batch_size,
high_res_masks=high_res_masks,
object_score_logits=consolidated_out["object_score_logits"],
is_mask_from_pts=True, # these frames are what the user interacted with
)
consolidated_out["maskmem_features"] = maskmem_features
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
return consolidated_out
def _get_empty_mask_ptr(self, inference_state, frame_idx):
"""Get a dummy object pointer based on an empty mask on the current frame."""
# A dummy (empty) mask with a single object
batch_size = 1
mask_inputs = torch.zeros(
(batch_size, 1, self.image_size, self.image_size),
dtype=torch.float32,
device=inference_state["device"],
)
# Retrieve correct image features
(
_,
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = self._get_image_feature(inference_state, frame_idx, batch_size)
# Feed the empty mask and image feature above to get a dummy object pointer
current_out = self.track_step(
frame_idx=frame_idx,
is_init_cond_frame=True,
current_vision_feats=current_vision_feats,
current_vision_pos_embeds=current_vision_pos_embeds,
feat_sizes=feat_sizes,
point_inputs=None,
mask_inputs=mask_inputs,
output_dict={},
num_frames=inference_state["num_frames"],
track_in_reverse=False,
run_mem_encoder=False,
prev_sam_mask_logits=None,
)
return current_out["obj_ptr"]
@torch.inference_mode()
def propagate_in_video_preflight(self, inference_state):
"""Prepare inference_state and consolidate temporary outputs before tracking."""
# Tracking has started and we don't allow adding new objects until session is reset.
inference_state["tracking_has_started"] = True
batch_size = self._get_obj_num(inference_state)
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
# add them into "output_dict".
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
output_dict = inference_state["output_dict"]
# "consolidated_frame_inds" contains indices of those frames where consolidated
# temporary outputs have been added (either in this call or any previous calls
# to `propagate_in_video_preflight`).
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
for is_cond in [False, True]:
# Separately consolidate conditioning and non-conditioning temp outputs
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Find all the frames that contain temporary outputs for any objects
# (these should be the frames that have just received clicks for mask inputs
# via `add_new_points_or_box` or `add_new_mask`)
temp_frame_inds = set()
for obj_temp_output_dict in temp_output_dict_per_obj.values():
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
consolidated_frame_inds[storage_key].update(temp_frame_inds)
# consolidate the temporary output across all objects on this frame
for frame_idx in temp_frame_inds:
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
)
# merge them into "output_dict" and also create per-object slices
output_dict[storage_key][frame_idx] = consolidated_out
self._add_output_per_object(
inference_state, frame_idx, consolidated_out, storage_key
)
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
if clear_non_cond_mem:
# clear non-conditioning memory of the surrounding frames
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
# clear temporary outputs in `temp_output_dict_per_obj`
for obj_temp_output_dict in temp_output_dict_per_obj.values():
obj_temp_output_dict[storage_key].clear()
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
# output on the same frame in "non_cond_frame_outputs"
for frame_idx in output_dict["cond_frame_outputs"]:
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
for obj_output_dict in inference_state["output_dict_per_obj"].values():
for frame_idx in obj_output_dict["cond_frame_outputs"]:
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
assert frame_idx in output_dict["cond_frame_outputs"]
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
# with either points or mask inputs (which should be true under a correct workflow).
all_consolidated_frame_inds = (
consolidated_frame_inds["cond_frame_outputs"]
| consolidated_frame_inds["non_cond_frame_outputs"]
)
input_frames_inds = set()
for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
input_frames_inds.update(point_inputs_per_frame.keys())
for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
input_frames_inds.update(mask_inputs_per_frame.keys())
assert all_consolidated_frame_inds == input_frames_inds
@torch.inference_mode()
def propagate_in_video(
self,
inference_state,
start_frame_idx=None,
max_frame_num_to_track=None,
reverse=False,
):
"""Propagate the input points across frames to track in the entire video."""
self.propagate_in_video_preflight(inference_state)
output_dict = inference_state["output_dict"]
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
obj_ids = inference_state["obj_ids"]
num_frames = inference_state["num_frames"]
batch_size = self._get_obj_num(inference_state)
if len(output_dict["cond_frame_outputs"]) == 0:
raise RuntimeError("No points are provided; please add points first")
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
# set start index, end index, and processing order
if start_frame_idx is None:
# default: start from the earliest frame with input points
start_frame_idx = min(output_dict["cond_frame_outputs"])
if max_frame_num_to_track is None:
# default: track all the frames in the video
max_frame_num_to_track = num_frames
if reverse:
end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
if start_frame_idx > 0:
processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
else:
processing_order = [] # skip reverse tracking if starting from frame 0
else:
end_frame_idx = min(
start_frame_idx + max_frame_num_to_track, num_frames - 1
)
processing_order = range(start_frame_idx, end_frame_idx + 1)
for frame_idx in tqdm(processing_order, desc="propagate in video"):
# We skip those frames already in consolidated outputs (these are frames
# that received input clicks or mask). Note that we cannot directly run
# batched forward on them via `_run_single_frame_inference` because the
# number of clicks on each object might be different.
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
storage_key = "cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
if clear_non_cond_mem:
# clear non-conditioning memory of the surrounding frames
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
storage_key = "non_cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
else:
storage_key = "non_cond_frame_outputs"
current_out, pred_masks = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=output_dict,
frame_idx=frame_idx,
batch_size=batch_size,
is_init_cond_frame=False,
point_inputs=None,
mask_inputs=None,
reverse=reverse,
run_mem_encoder=True,
)
output_dict[storage_key][frame_idx] = current_out
# Create slices of per-object outputs for subsequent interaction with each
# individual object after tracking.
self._add_output_per_object(
inference_state, frame_idx, current_out, storage_key
)
inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
# Resize the output mask to the original video resolution (we directly use
# the mask scores on GPU for output to avoid any CPU conversion in between)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, pred_masks
)
yield frame_idx, obj_ids, video_res_masks
def _add_output_per_object(
self, inference_state, frame_idx, current_out, storage_key
):
"""
Split a multi-object output into per-object output slices and add them into
`output_dict_per_obj`. The resulting slices share the same tensor storage.
"""
maskmem_features = current_out["maskmem_features"]
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
maskmem_pos_enc = current_out["maskmem_pos_enc"]
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
output_dict_per_obj = inference_state["output_dict_per_obj"]
for obj_idx, obj_output_dict in output_dict_per_obj.items():
obj_slice = slice(obj_idx, obj_idx + 1)
obj_out = {
"maskmem_features": None,
"maskmem_pos_enc": None,
"pred_masks": current_out["pred_masks"][obj_slice],
"obj_ptr": current_out["obj_ptr"][obj_slice],
"object_score_logits": current_out["object_score_logits"][obj_slice],
}
if maskmem_features is not None:
obj_out["maskmem_features"] = maskmem_features[obj_slice]
if maskmem_pos_enc is not None:
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
obj_output_dict[storage_key][frame_idx] = obj_out
@torch.inference_mode()
def clear_all_prompts_in_frame(
self, inference_state, frame_idx, obj_id, need_output=True
):
"""Remove all input points or mask in a specific frame for a given object."""
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
# Clear the conditioning information on the given frame
inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
# Check and see if there are still any inputs left on this frame
batch_size = self._get_obj_num(inference_state)
frame_has_input = False
for obj_idx2 in range(batch_size):
if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
frame_has_input = True
break
if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
frame_has_input = True
break
# If this frame has no remaining inputs for any objects, we further clear its
# conditioning frame status
if not frame_has_input:
output_dict = inference_state["output_dict"]
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
if out is not None:
# The frame is not a conditioning frame anymore since it's not receiving inputs,
# so we "downgrade" its output (if exists) to a non-conditioning frame output.
output_dict["non_cond_frame_outputs"][frame_idx] = out
inference_state["frames_already_tracked"].pop(frame_idx, None)
# Similarly, do it for the sliced output on each object.
for obj_idx2 in range(batch_size):
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
if obj_out is not None:
obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
# If all the conditioning frames have been removed, we also clear the tracking outputs
if len(output_dict["cond_frame_outputs"]) == 0:
self._reset_tracking_results(inference_state)
if not need_output:
return
# Finally, output updated masks per object (after removing the inputs above)
obj_ids = inference_state["obj_ids"]
is_cond = any(
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
for obj_temp_output_dict in temp_output_dict_per_obj.values()
)
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
return frame_idx, obj_ids, video_res_masks
@torch.inference_mode()
def reset_state(self, inference_state):
"""Remove all input points or mask in all frames throughout the video."""
self._reset_tracking_results(inference_state)
# Remove all object ids
inference_state["obj_id_to_idx"].clear()
inference_state["obj_idx_to_id"].clear()
inference_state["obj_ids"].clear()
inference_state["point_inputs_per_obj"].clear()
inference_state["mask_inputs_per_obj"].clear()
inference_state["output_dict_per_obj"].clear()
inference_state["temp_output_dict_per_obj"].clear()
def _reset_tracking_results(self, inference_state):
"""Reset all tracking inputs and results across the videos."""
for v in inference_state["point_inputs_per_obj"].values():
v.clear()
for v in inference_state["mask_inputs_per_obj"].values():
v.clear()
for v in inference_state["output_dict_per_obj"].values():
v["cond_frame_outputs"].clear()
v["non_cond_frame_outputs"].clear()
for v in inference_state["temp_output_dict_per_obj"].values():
v["cond_frame_outputs"].clear()
v["non_cond_frame_outputs"].clear()
inference_state["output_dict"]["cond_frame_outputs"].clear()
inference_state["output_dict"]["non_cond_frame_outputs"].clear()
inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"].clear()
def _get_image_feature(self, inference_state, frame_idx, batch_size):
"""Compute the image features on a given frame."""
# Look up in the cache first
image, backbone_out = inference_state["cached_features"].get(
frame_idx, (None, None)
)
if backbone_out is None:
# Cache miss -- we will run inference on a single image
device = inference_state["device"]
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future).
inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
# expand the features to have the same dimension as the number of objects
expanded_image = image.expand(batch_size, -1, -1, -1)
expanded_backbone_out = {
"backbone_fpn": backbone_out["backbone_fpn"].copy(),
"vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
}
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
expanded_backbone_out["backbone_fpn"][i] = feat.expand(
batch_size, -1, -1, -1
)
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
pos = pos.expand(batch_size, -1, -1, -1)
expanded_backbone_out["vision_pos_enc"][i] = pos
features = self._prepare_backbone_features(expanded_backbone_out)
features = (expanded_image,) + features
return features
def _run_single_frame_inference(
self,
inference_state,
output_dict,
frame_idx,
batch_size,
is_init_cond_frame,
point_inputs,
mask_inputs,
reverse,
run_mem_encoder,
prev_sam_mask_logits=None,
):
"""Run tracking on a single frame based on current inputs and previous memory."""
# Retrieve correct image features
(
_,
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = self._get_image_feature(inference_state, frame_idx, batch_size)
# point and mask should not appear as input simultaneously on the same frame
assert point_inputs is None or mask_inputs is None
current_out = self.track_step(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats,
current_vision_pos_embeds=current_vision_pos_embeds,
feat_sizes=feat_sizes,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
output_dict=output_dict,
num_frames=inference_state["num_frames"],
track_in_reverse=reverse,
run_mem_encoder=run_mem_encoder,
prev_sam_mask_logits=prev_sam_mask_logits,
)
# optionally offload the output to CPU memory to save GPU space
storage_device = inference_state["storage_device"]
maskmem_features = current_out["maskmem_features"]
if maskmem_features is not None:
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
pred_masks_gpu = current_out["pred_masks"]
# potentially fill holes in the predicted masks
if self.fill_hole_area > 0:
pred_masks_gpu = fill_holes_in_mask_scores(
pred_masks_gpu, self.fill_hole_area
)
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
obj_ptr = current_out["obj_ptr"]
object_score_logits = current_out["object_score_logits"]
# make a compact version of this frame's output to reduce the state size
compact_current_out = {
"maskmem_features": maskmem_features,
"maskmem_pos_enc": maskmem_pos_enc,
"pred_masks": pred_masks,
"obj_ptr": obj_ptr,
"object_score_logits": object_score_logits,
}
return compact_current_out, pred_masks_gpu
def _run_memory_encoder(
self,
inference_state,
frame_idx,
batch_size,
high_res_masks,
object_score_logits,
is_mask_from_pts,
):
"""
Run the memory encoder on `high_res_masks`. This is usually after applying
non-overlapping constraints to object scores. Since their scores changed, their
memory also need to be computed again with the memory encoder.
"""
# Retrieve correct image features
_, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
inference_state, frame_idx, batch_size
)
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_masks,
object_score_logits=object_score_logits,
is_mask_from_pts=is_mask_from_pts,
)
# optionally offload the output to CPU memory to save GPU space
storage_device = inference_state["storage_device"]
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
maskmem_pos_enc = self._get_maskmem_pos_enc(
inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
)
return maskmem_features, maskmem_pos_enc
def _get_maskmem_pos_enc(self, inference_state, current_out):
"""
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
a constant in the inference session to reduce session storage size.
"""
model_constants = inference_state["constants"]
# "out_maskmem_pos_enc" should be either a list of tensors or None
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
if out_maskmem_pos_enc is not None:
if "maskmem_pos_enc" not in model_constants:
assert isinstance(out_maskmem_pos_enc, list)
# only take the slice for one object, since it's same across objects
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
else:
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
# expand the cached maskmem_pos_enc to the actual batch size
batch_size = out_maskmem_pos_enc[0].size(0)
expanded_maskmem_pos_enc = [
x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
]
else:
expanded_maskmem_pos_enc = None
return expanded_maskmem_pos_enc
@torch.inference_mode()
def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
"""
Remove an object id from the tracking state. If strict is True, we check whether
the object id actually exists and raise an error if it doesn't exist.
"""
old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
updated_frames = []
# Check whether this object_id to remove actually exists and possibly raise an error.
if old_obj_idx_to_rm is None:
if not strict:
return inference_state["obj_ids"], updated_frames
raise RuntimeError(
f"Cannot remove object id {obj_id} as it doesn't exist. "
f"All existing object ids: {inference_state['obj_ids']}."
)
# If this is the only remaining object id, we simply reset the state.
if len(inference_state["obj_id_to_idx"]) == 1:
self.reset_state(inference_state)
return inference_state["obj_ids"], updated_frames
# There are still remaining objects after removing this object id. In this case,
# we need to delete the object storage from inference state tensors.
# Step 0: clear the input on those frames where this object id has point or mask input
# (note that this step is required as it might downgrade conditioning frames to
# non-conditioning ones)
obj_input_frames_inds = set()
obj_input_frames_inds.update(
inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
)
obj_input_frames_inds.update(
inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
)
for frame_idx in obj_input_frames_inds:
self.clear_all_prompts_in_frame(
inference_state, frame_idx, obj_id, need_output=False
)
# Step 1: Update the object id mapping (note that it must be done after Step 0,
# since Step 0 still requires the old object id mappings in inference_state)
old_obj_ids = inference_state["obj_ids"]
old_obj_inds = list(range(len(old_obj_ids)))
remain_old_obj_inds = old_obj_inds.copy()
remain_old_obj_inds.remove(old_obj_idx_to_rm)
new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
new_obj_inds = list(range(len(new_obj_ids)))
# build new mappings
old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
inference_state["obj_ids"] = new_obj_ids
# Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
# (note that "consolidated_frame_inds" doesn't need to be updated in this step as
# it's already handled in Step 0)
def _map_keys(container):
new_kvs = []
for k in old_obj_inds:
v = container.pop(k)
if k in old_idx_to_new_idx:
new_kvs.append((old_idx_to_new_idx[k], v))
container.update(new_kvs)
_map_keys(inference_state["point_inputs_per_obj"])
_map_keys(inference_state["mask_inputs_per_obj"])
_map_keys(inference_state["output_dict_per_obj"])
_map_keys(inference_state["temp_output_dict_per_obj"])
# Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
def _slice_state(output_dict, storage_key):
for frame_idx, out in output_dict[storage_key].items():
out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
out["maskmem_pos_enc"] = [
x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
]
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
out["object_score_logits"] = out["object_score_logits"][
remain_old_obj_inds
]
# also update the per-object slices
self._add_output_per_object(
inference_state, frame_idx, out, storage_key
)
_slice_state(inference_state["output_dict"], "cond_frame_outputs")
_slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
# Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
# could show an updated mask for objects previously occluded by the object being removed
if need_output:
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
for frame_idx in obj_input_frames_inds:
is_cond = any(
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
for obj_temp_output_dict in temp_output_dict_per_obj.values()
)
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
updated_frames.append((frame_idx, video_res_masks))
return inference_state["obj_ids"], updated_frames
def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
"""
Remove the non-conditioning memory around the input frame. When users provide
correction clicks, the surrounding frames' non-conditioning memories can still
contain outdated object appearance information and could confuse the model.
This method clears those non-conditioning memories surrounding the interacted
frame to avoid giving the model both old and new information about the object.
"""
r = self.memory_temporal_stride_for_eval
frame_idx_begin = frame_idx - r * self.num_maskmem
frame_idx_end = frame_idx + r * self.num_maskmem
output_dict = inference_state["output_dict"]
non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
for t in range(frame_idx_begin, frame_idx_end + 1):
non_cond_frame_outputs.pop(t, None)
for obj_output_dict in inference_state["output_dict_per_obj"].values():
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple
import numpy as np
import torch
# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py
class MaskData:
"""
A structure for storing masks and their related data in batched format.
Implements basic filtering and concatenation.
"""
def __init__(self, **kwargs) -> None:
for v in kwargs.values():
assert isinstance(
v, (list, np.ndarray, torch.Tensor)
), "MaskData only supports list, numpy arrays, and torch tensors."
self._stats = dict(**kwargs)
def __setitem__(self, key: str, item: Any) -> None:
assert isinstance(
item, (list, np.ndarray, torch.Tensor)
), "MaskData only supports list, numpy arrays, and torch tensors."
self._stats[key] = item
def __delitem__(self, key: str) -> None:
del self._stats[key]
def __getitem__(self, key: str) -> Any:
return self._stats[key]
def items(self) -> ItemsView[str, Any]:
return self._stats.items()
def filter(self, keep: torch.Tensor) -> None:
for k, v in self._stats.items():
if v is None:
self._stats[k] = None
elif isinstance(v, torch.Tensor):
self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
elif isinstance(v, np.ndarray):
self._stats[k] = v[keep.detach().cpu().numpy()]
elif isinstance(v, list) and keep.dtype == torch.bool:
self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
elif isinstance(v, list):
self._stats[k] = [v[i] for i in keep]
else:
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
def cat(self, new_stats: "MaskData") -> None:
for k, v in new_stats.items():
if k not in self._stats or self._stats[k] is None:
self._stats[k] = deepcopy(v)
elif isinstance(v, torch.Tensor):
self._stats[k] = torch.cat([self._stats[k], v], dim=0)
elif isinstance(v, np.ndarray):
self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
elif isinstance(v, list):
self._stats[k] = self._stats[k] + deepcopy(v)
else:
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
def to_numpy(self) -> None:
for k, v in self._stats.items():
if isinstance(v, torch.Tensor):
self._stats[k] = v.float().detach().cpu().numpy()
def is_box_near_crop_edge(
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
return torch.any(near_crop_edge, dim=1)
def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
box_xywh = deepcopy(box_xyxy)
box_xywh[2] = box_xywh[2] - box_xywh[0]
box_xywh[3] = box_xywh[3] - box_xywh[1]
return box_xywh
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
assert len(args) > 0 and all(
len(a) == len(args[0]) for a in args
), "Batched iteration must have inputs of all the same size."
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
"""
Encodes masks to an uncompressed RLE, in the format expected by
pycoco tools.
"""
# Put in fortran order and flatten h,w
b, h, w = tensor.shape
tensor = tensor.permute(0, 2, 1).flatten(1)
# Compute change indices
diff = tensor[:, 1:] ^ tensor[:, :-1]
change_indices = diff.nonzero()
# Encode run length
out = []
for i in range(b):
cur_idxs = change_indices[change_indices[:, 0] == i, 1]
cur_idxs = torch.cat(
[
torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
cur_idxs + 1,
torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
]
)
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if tensor[i, 0] == 0 else [0]
counts.extend(btw_idxs.detach().cpu().tolist())
out.append({"size": [h, w], "counts": counts})
return out
def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
"""Compute a binary mask from an uncompressed RLE."""
h, w = rle["size"]
mask = np.empty(h * w, dtype=bool)
idx = 0
parity = False
for count in rle["counts"]:
mask[idx : idx + count] = parity
idx += count
parity ^= True
mask = mask.reshape(w, h)
return mask.transpose() # Put in C order
def area_from_rle(rle: Dict[str, Any]) -> int:
return sum(rle["counts"][1::2])
def calculate_stability_score(
masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
"""
Computes the stability score for a batch of masks. The stability
score is the IoU between the binary masks obtained by thresholding
the predicted mask logits at high and low values.
"""
# One mask is always contained inside the other.
# Save memory by preventing unnecessary cast to torch.int64
intersections = (
(masks > (mask_threshold + threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
)
unions = (
(masks > (mask_threshold - threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
)
return intersections / unions
def build_point_grid(n_per_side: int) -> np.ndarray:
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
return points
def build_all_layer_point_grids(
n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[np.ndarray]:
"""Generates point grids for all crop layers."""
points_by_layer = []
for i in range(n_layers + 1):
n_points = int(n_per_side / (scale_per_layer**i))
points_by_layer.append(build_point_grid(n_points))
return points_by_layer
def generate_crop_boxes(
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
"""
Generates a list of crop boxes of different sizes. Each layer
has (2**i)**2 boxes for the ith layer.
"""
crop_boxes, layer_idxs = [], []
im_h, im_w = im_size
short_side = min(im_h, im_w)
# Original image
crop_boxes.append([0, 0, im_w, im_h])
layer_idxs.append(0)
def crop_len(orig_len, n_crops, overlap):
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
for i_layer in range(n_layers):
n_crops_per_side = 2 ** (i_layer + 1)
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
crop_w = crop_len(im_w, n_crops_per_side, overlap)
crop_h = crop_len(im_h, n_crops_per_side, overlap)
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
# Crops in XYWH format
for x0, y0 in product(crop_box_x0, crop_box_y0):
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
crop_boxes.append(box)
layer_idxs.append(i_layer + 1)
return crop_boxes, layer_idxs
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
# Check if boxes has a channel dimension
if len(boxes.shape) == 3:
offset = offset.unsqueeze(1)
return boxes + offset
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0]], device=points.device)
# Check if points has a channel dimension
if len(points.shape) == 3:
offset = offset.unsqueeze(1)
return points + offset
def uncrop_masks(
masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
) -> torch.Tensor:
x0, y0, x1, y1 = crop_box
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
return masks
# Coordinate transform masks
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
pad = (x0, pad_x - x0, y0, pad_y - y0)
return torch.nn.functional.pad(masks, pad, value=0)
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of if the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
from pycocotools import mask as mask_utils # type: ignore
h, w = uncompressed_rle["size"]
rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
return rle
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
"""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
# Normalize shape to CxHxW
shape = masks.shape
h, w = shape[-2:]
if len(shape) > 2:
masks = masks.flatten(0, -3)
else:
masks = masks.unsqueeze(0)
# Get top and bottom edges
in_height, _ = torch.max(masks, dim=-1)
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
in_height_coords = in_height_coords + h * (~in_height)
top_edges, _ = torch.min(in_height_coords, dim=-1)
# Get left and right edges
in_width, _ = torch.max(masks, dim=-2)
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
right_edges, _ = torch.max(in_width_coords, dim=-1)
in_width_coords = in_width_coords + w * (~in_width)
left_edges, _ = torch.min(in_width_coords, dim=-1)
# If the mask is empty the right edge will be to the left of the left edge.
# Replace these boxes with [0, 0, 0, 0]
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
out = out * (~empty_filter).unsqueeze(-1)
# Return to original shape
if len(shape) > 2:
out = out.reshape(*shape[:-2], 4)
else:
out = out[0]
return out
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import warnings
from threading import Thread
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
def get_sdpa_settings():
if torch.cuda.is_available():
old_gpu = torch.cuda.get_device_properties(0).major < 7
# only use Flash Attention on Ampere (8.0) or newer GPUs
use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
if not use_flash_attn:
warnings.warn(
"Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
category=UserWarning,
stacklevel=2,
)
# keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
# available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
if pytorch_version < (2, 2):
warnings.warn(
f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
"Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
category=UserWarning,
stacklevel=2,
)
math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
else:
old_gpu = True
use_flash_attn = False
math_kernel_on = True
return old_gpu, use_flash_attn, math_kernel_on
def get_connected_components(mask):
"""
Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
Inputs:
- mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
background.
Outputs:
- labels: A tensor of shape (N, 1, H, W) containing the connected component labels
for foreground pixels and 0 for background pixels.
- counts: A tensor of shape (N, 1, H, W) containing the area of the connected
components for foreground pixels and 0 for background pixels.
"""
from sam2 import _C
return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
def mask_to_box(masks: torch.Tensor):
"""
compute bounding box given an input mask
Inputs:
- masks: [B, 1, H, W] masks, dtype=torch.Tensor
Returns:
- box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
"""
B, _, h, w = masks.shape
device = masks.device
xs = torch.arange(w, device=device, dtype=torch.int32)
ys = torch.arange(h, device=device, dtype=torch.int32)
grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)
return bbox_coords
def _load_img_as_tensor(img_path, image_size):
img_pil = Image.open(img_path)
img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
img_np = img_np / 255.0
else:
raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
img = torch.from_numpy(img_np).permute(2, 0, 1)
video_width, video_height = img_pil.size # the original video size
return img, video_height, video_width
class AsyncVideoFrameLoader:
"""
A list of video frames to be load asynchronously without blocking session start.
"""
def __init__(
self,
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
):
self.img_paths = img_paths
self.image_size = image_size
self.offload_video_to_cpu = offload_video_to_cpu
self.img_mean = img_mean
self.img_std = img_std
# items in `self.images` will be loaded asynchronously
self.images = [None] * len(img_paths)
# catch and raise any exceptions in the async loading thread
self.exception = None
# video_height and video_width be filled when loading the first image
self.video_height = None
self.video_width = None
self.compute_device = compute_device
# load the first frame to fill video_height and video_width and also
# to cache it (since it's most likely where the user will click)
self.__getitem__(0)
# load the rest of frames asynchronously without blocking the session start
def _load_frames():
try:
for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
self.__getitem__(n)
except Exception as e:
self.exception = e
self.thread = Thread(target=_load_frames, daemon=True)
self.thread.start()
def __getitem__(self, index):
if self.exception is not None:
raise RuntimeError("Failure in frame loading thread") from self.exception
img = self.images[index]
if img is not None:
return img
img, video_height, video_width = _load_img_as_tensor(
self.img_paths[index], self.image_size
)
self.video_height = video_height
self.video_width = video_width
# normalize by mean and std
img -= self.img_mean
img /= self.img_std
if not self.offload_video_to_cpu:
img = img.to(self.compute_device, non_blocking=True)
self.images[index] = img
return img
def __len__(self):
return len(self.images)
def load_video_frames(
video_path,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
compute_device=torch.device("cuda"),
):
"""
Load the video frames from video_path. The frames are resized to image_size as in
the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo.
"""
is_bytes = isinstance(video_path, bytes)
is_str = isinstance(video_path, str)
is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"]
if is_bytes or is_mp4_path:
return load_video_frames_from_video_file(
video_path=video_path,
image_size=image_size,
offload_video_to_cpu=offload_video_to_cpu,
img_mean=img_mean,
img_std=img_std,
compute_device=compute_device,
)
elif is_str and os.path.isdir(video_path):
return load_video_frames_from_jpg_images(
video_path=video_path,
image_size=image_size,
offload_video_to_cpu=offload_video_to_cpu,
img_mean=img_mean,
img_std=img_std,
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
else:
raise NotImplementedError(
"Only MP4 video and JPEG folder are supported at this moment"
)
def load_video_frames_from_jpg_images(
video_path,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
compute_device=torch.device("cuda"),
):
"""
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
The frames are resized to image_size x image_size and are loaded to GPU if
`offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
You can load a frame asynchronously by setting `async_loading_frames` to `True`.
"""
if isinstance(video_path, str) and os.path.isdir(video_path):
jpg_folder = video_path
else:
raise NotImplementedError(
"Only JPEG frames are supported at this moment. For video files, you may use "
"ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
"```\n"
"ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
"```\n"
"where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
"ffmpeg to start the JPEG file from 00000.jpg."
)
frame_names = [
p
for p in os.listdir(jpg_folder)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
num_frames = len(frame_names)
if num_frames == 0:
raise RuntimeError(f"no images found in {jpg_folder}")
img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
if async_loading_frames:
lazy_images = AsyncVideoFrameLoader(
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
)
return lazy_images, lazy_images.video_height, lazy_images.video_width
images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
images = images.to(compute_device)
img_mean = img_mean.to(compute_device)
img_std = img_std.to(compute_device)
# normalize by mean and std
images -= img_mean
images /= img_std
return images, video_height, video_width
def load_video_frames_from_video_file(
video_path,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
compute_device=torch.device("cuda"),
):
"""Load the video frames from a video file."""
import decord
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
# Get the original video height and width
decord.bridge.set_bridge("torch")
video_height, video_width, _ = decord.VideoReader(video_path).next().shape
# Iterate over all frames in the video
images = []
for frame in decord.VideoReader(video_path, width=image_size, height=image_size):
images.append(frame.permute(2, 0, 1))
images = torch.stack(images, dim=0).float() / 255.0
if not offload_video_to_cpu:
images = images.to(compute_device)
img_mean = img_mean.to(compute_device)
img_std = img_std.to(compute_device)
# normalize by mean and std
images -= img_mean
images /= img_std
return images, video_height, video_width
def fill_holes_in_mask_scores(mask, max_area):
"""
A post processor to fill small holes in mask scores with area under `max_area`.
"""
# Holes are those connected components in background with area <= self.max_area
# (background regions are those with mask scores <= 0)
assert max_area > 0, "max_area must be positive"
input_mask = mask
try:
labels, areas = get_connected_components(mask <= 0)
is_hole = (labels > 0) & (areas <= max_area)
# We fill holes with a small positive mask score (0.1) to change them to foreground.
mask = torch.where(is_hole, 0.1, mask)
except Exception as e:
# Skip the post-processing step on removing small holes if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
"functionality may be limited (which doesn't affect the results in most cases; see "
"https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,
)
mask = input_mask
return mask
def concat_points(old_point_inputs, new_points, new_labels):
"""Add new points and labels to previous point inputs (add at the end)."""
if old_point_inputs is None:
points, labels = new_points, new_labels
else:
points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
return {"point_coords": points, "point_labels": labels}
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Normalize, Resize, ToTensor
class SAM2Transforms(nn.Module):
def __init__(
self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
):
"""
Transforms for SAM2.
"""
super().__init__()
self.resolution = resolution
self.mask_threshold = mask_threshold
self.max_hole_area = max_hole_area
self.max_sprinkle_area = max_sprinkle_area
self.mean = [0.485, 0.456, 0.406]
self.std = [0.229, 0.224, 0.225]
self.to_tensor = ToTensor()
self.transforms = torch.jit.script(
nn.Sequential(
Resize((self.resolution, self.resolution)),
Normalize(self.mean, self.std),
)
)
def __call__(self, x):
x = self.to_tensor(x)
return self.transforms(x)
def forward_batch(self, img_list):
img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
img_batch = torch.stack(img_batch, dim=0)
return img_batch
def transform_coords(
self, coords: torch.Tensor, normalize=False, orig_hw=None
) -> torch.Tensor:
"""
Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates,
If the coords are in absolute image coordinates, normalize should be set to True and original image size is required.
Returns
Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model.
"""
if normalize:
assert orig_hw is not None
h, w = orig_hw
coords = coords.clone()
coords[..., 0] = coords[..., 0] / w
coords[..., 1] = coords[..., 1] / h
coords = coords * self.resolution # unnormalize coords
return coords
def transform_boxes(
self, boxes: torch.Tensor, normalize=False, orig_hw=None
) -> torch.Tensor:
"""
Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates,
if the coords are in absolute image coordinates, normalize should be set to True and original image size is required.
"""
boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
return boxes
def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
"""
Perform PostProcessing on output masks.
"""
from sam2.utils.misc import get_connected_components
masks = masks.float()
input_masks = masks
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
try:
if self.max_hole_area > 0:
# Holes are those connected components in background with area <= self.fill_hole_area
# (background regions are those with mask scores <= self.mask_threshold)
labels, areas = get_connected_components(
mask_flat <= self.mask_threshold
)
is_hole = (labels > 0) & (areas <= self.max_hole_area)
is_hole = is_hole.reshape_as(masks)
# We fill holes with a small positive mask score (10.0) to change them to foreground.
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
if self.max_sprinkle_area > 0:
labels, areas = get_connected_components(
mask_flat > self.mask_threshold
)
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
is_hole = is_hole.reshape_as(masks)
# We fill holes with negative mask score (-10.0) to change them to background.
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
except Exception as e:
# Skip the post-processing step if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
"functionality may be limited (which doesn't affect the results in most cases; see "
"https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,
)
masks = input_masks
masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
return masks
BSD License
For SAM 2 Eval software
Copyright (c) Meta Platforms, Inc. and affiliates.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name Meta nor the names of its contributors may be used to
endorse or promote products derived from this software without specific
prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
BSD 3-Clause License
Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
Copyright 2023 Rex Cheng
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
\ No newline at end of file
# Segment Anything Video (SA-V) Dataset
## Overview
[Segment Anything Video (SA-V)](https://ai.meta.com/datasets/segment-anything-video/), consists of 51K diverse videos and 643K high-quality spatio-temporal segmentation masks (i.e., masklets). The dataset is released under the CC by 4.0 license. Browse the dataset [here](https://sam2.metademolab.com/dataset).
![SA-V dataset](../assets/sa_v_dataset.jpg?raw=true)
## Getting Started
### Download the dataset
Visit [here](https://ai.meta.com/datasets/segment-anything-video-downloads/) to download SA-V including the training, val and test sets.
### Dataset Stats
| | Num Videos | Num Masklets |
| ---------- | ---------- | ----------------------------------------- |
| SA-V train | 50,583 | 642,036 (auto 451,720 and manual 190,316) |
| SA-V val | 155 | 293 |
| SA-V test | 150 | 278 |
### Notebooks
To load and visualize the SA-V training set annotations, refer to the example [sav_visualization_example.ipynb](./sav_visualization_example.ipynb) notebook.
### SA-V train
For SA-V training set we release the mp4 videos and store the masklet annotations per video as json files . Automatic masklets and manual masklets are stored separately as two json files: `{video_id}_auto.json` and `{video_id}_manual.json`. They can be loaded as dictionaries in python in the format below.
```
{
"video_id" : str; video id
"video_duration" : float64; the duration in seconds of this video
"video_frame_count" : float64; the number of frames in the video
"video_height" : float64; the height of the video
"video_width" : float64; the width of the video
"video_resolution" : float64; video_height $\times$ video_width
"video_environment" : List[str]; "Indoor" or "Outdoor"
"video_split" : str; "train" for training set
"masklet" : List[List[Dict]]; masklet annotations in list of list of RLEs.
The outer list is over frames in the video and the inner list
is over objects in the video.
"masklet_id" : List[int]; the masklet ids
"masklet_size_rel" : List[float]; the average mask area normalized by resolution
across all the frames where the object is visible
"masklet_size_abs" : List[float]; the average mask area (in pixels)
across all the frames where the object is visible
"masklet_size_bucket" : List[str]; "small": $1$ <= masklet_size_abs < $32^2$,
"medium": $32^2$ <= masklet_size_abs < $96^2$,
and "large": masklet_size_abs > $96^2$
"masklet_visibility_changes" : List[int]; the number of times where the visibility changes
after the first appearance (e.g., invisible -> visible
or visible -> invisible)
"masklet_first_appeared_frame" : List[int]; the index of the frame where the object appears
the first time in the video. Always 0 for auto masklets.
"masklet_frame_count" : List[int]; the number of frames being annotated. Note that
videos are annotated at 6 fps (annotated every 4 frames)
while the videos are at 24 fps.
"masklet_edited_frame_count" : List[int]; the number of frames being edited by human annotators.
Always 0 for auto masklets.
"masklet_type" : List[str]; "auto" or "manual"
"masklet_stability_score" : Optional[List[List[float]]]; per-mask stability scores. Auto annotation only.
"masklet_num" : int; the number of manual/auto masklets in the video
}
```
Note that in SA-V train, there are in total 50,583 videos where all of them have manual annotations. Among the 50,583 videos there are 48,436 videos that also have automatic annotations.
### SA-V val and test
For SA-V val and test sets, we release the extracted frames as jpeg files, and the masks as png files with the following directory structure:
```
sav_val(sav_test)
├── sav_val.txt (sav_test.txt): a list of video ids in the split
├── JPEGImages_24fps # videos are extracted at 24 fps
│ ├── {video_id}
│ │ ├── 00000.jpg # video frame
│ │ ├── 00001.jpg # video frame
│ │ ├── 00002.jpg # video frame
│ │ ├── 00003.jpg # video frame
│ │ └── ...
│ ├── {video_id}
│ ├── {video_id}
│ └── ...
└── Annotations_6fps # videos are annotated at 6 fps
├── {video_id}
│ ├── 000 # obj 000
│ │ ├── 00000.png # mask for object 000 in 00000.jpg
│ │ ├── 00004.png # mask for object 000 in 00004.jpg
│ │ ├── 00008.png # mask for object 000 in 00008.jpg
│ │ ├── 00012.png # mask for object 000 in 00012.jpg
│ │ └── ...
│ ├── 001 # obj 001
│ ├── 002 # obj 002
│ └── ...
├── {video_id}
├── {video_id}
└── ...
```
All masklets in val and test sets are manually annotated in every frame by annotators. For each annotated object in a video, we store the annotated masks in a single png. This is because the annotated objects may overlap, e.g., it is possible in our SA-V dataset for there to be a mask for the whole person as well as a separate mask for their hands.
## SA-V Val and Test Evaluation
We provide an evaluator to compute the common J and F metrics on SA-V val and test sets. To run the evaluation, we need to first install a few dependencies as follows:
```
pip install -r requirements.txt
```
Then we can evaluate the predictions as follows:
```
python sav_evaluator.py --gt_root {GT_ROOT} --pred_root {PRED_ROOT}
```
or run
```
python sav_evaluator.py --help
```
to print a complete help message.
The evaluator expects the `GT_ROOT` to be one of the following folder structures, and `GT_ROOT` and `PRED_ROOT` to have the same structure.
- Same as SA-V val and test directory structure
```
{GT_ROOT} # gt root folder
├── {video_id}
│ ├── 000 # all masks associated with obj 000
│ │ ├── 00000.png # mask for object 000 in frame 00000 (binary mask)
│ │ └── ...
│ ├── 001 # all masks associated with obj 001
│ ├── 002 # all masks associated with obj 002
│ └── ...
├── {video_id}
├── {video_id}
└── ...
```
In the paper for the experiments on SA-V val and test, we run inference on the 24 fps videos, and evaluate on the subset of frames where we have ground truth annotations (first and last annotated frames dropped). The evaluator will ignore the masks in frames where we don't have ground truth annotations.
- Same as [DAVIS](https://github.com/davisvideochallenge/davis2017-evaluation) directory structure
```
{GT_ROOT} # gt root folder
├── {video_id}
│ ├── 00000.png # annotations in frame 00000 (may contain multiple objects)
│ └── ...
├── {video_id}
├── {video_id}
└── ...
```
## License
The evaluation code is licensed under the [BSD 3 license](./LICENSE). Please refer to the paper for more details on the models. The videos and annotations in SA-V Dataset are released under CC BY 4.0.
Third-party code: the evaluation software is heavily adapted from [`VOS-Benchmark`](https://github.com/hkchengrex/vos-benchmark) and [`DAVIS`](https://github.com/davisvideochallenge/davis2017-evaluation) (with their licenses in [`LICENSE_DAVIS`](./LICENSE_DAVIS) and [`LICENSE_VOS_BENCHMARK`](./LICENSE_VOS_BENCHMARK)).
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
pycocoevalcap
scikit-image
opencv-python
tqdm
pillow
numpy
matplotlib
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the sav_dataset directory of this source tree.
# adapted from https://github.com/hkchengrex/vos-benchmark
# and https://github.com/davisvideochallenge/davis2017-evaluation
# with their licenses found in the LICENSE_VOS_BENCHMARK and LICENSE_DAVIS files
# in the sav_dataset directory.
from argparse import ArgumentParser
from utils.sav_benchmark import benchmark
"""
The structure of the {GT_ROOT} can be either of the follow two structures.
{GT_ROOT} and {PRED_ROOT} should be of the same format
1. SA-V val/test structure
{GT_ROOT} # gt root folder
├── {video_id}
│ ├── 000 # all masks associated with obj 000
│ │ ├── {frame_id}.png # mask for object 000 in {frame_id} (binary mask)
│ │ └── ...
│ ├── 001 # all masks associated with obj 001
│ ├── 002 # all masks associated with obj 002
│ └── ...
├── {video_id}
├── {video_id}
└── ...
2. Similar to DAVIS structure:
{GT_ROOT} # gt root folder
├── {video_id}
│ ├── {frame_id}.png # annotation in {frame_id} (may contain multiple objects)
│ └── ...
├── {video_id}
├── {video_id}
└── ...
"""
parser = ArgumentParser()
parser.add_argument(
"--gt_root",
required=True,
help="Path to the GT folder. For SA-V, it's sav_val/Annotations_6fps or sav_test/Annotations_6fps",
)
parser.add_argument(
"--pred_root",
required=True,
help="Path to a folder containing folders of masks to be evaluated, with exactly the same structure as gt_root",
)
parser.add_argument(
"-n", "--num_processes", default=16, type=int, help="Number of concurrent processes"
)
parser.add_argument(
"-s",
"--strict",
help="Make sure every video in the gt_root folder has a corresponding video in the prediction",
action="store_true",
)
parser.add_argument(
"-q",
"--quiet",
help="Quietly run evaluation without printing the information out",
action="store_true",
)
# https://github.com/davisvideochallenge/davis2017-evaluation/blob/d34fdef71ce3cb24c1a167d860b707e575b3034c/davis2017/evaluation.py#L85
parser.add_argument(
"--do_not_skip_first_and_last_frame",
help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. "
"Set this to true for evaluation on settings that doesn't skip first and last frames",
action="store_true",
)
if __name__ == "__main__":
args = parser.parse_args()
benchmark(
[args.gt_root],
[args.pred_root],
args.strict,
args.num_processes,
verbose=not args.quiet,
skip_first_and_last=not args.do_not_skip_first_and_last_frame,
)
This source diff could not be displayed because it is too large. You can view the blob instead.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the sav_dataset directory of this source tree.
# adapted from https://github.com/hkchengrex/vos-benchmark
# and https://github.com/davisvideochallenge/davis2017-evaluation
# with their licenses found in the LICENSE_VOS_BENCHMARK and LICENSE_DAVIS files
# in the sav_dataset directory.
import math
import os
import time
from collections import defaultdict
from multiprocessing import Pool
from os import path
from typing import Dict, List, Tuple
import cv2
import numpy as np
import tqdm
from PIL import Image
from skimage.morphology import disk
class VideoEvaluator:
def __init__(self, gt_root, pred_root, skip_first_and_last=True) -> None:
"""
gt_root: path to the folder storing the gt masks
pred_root: path to the folder storing the predicted masks
skip_first_and_last: whether we should skip the evaluation of the first and the last frame.
True for SA-V val and test, same as in DAVIS semi-supervised evaluation.
"""
self.gt_root = gt_root
self.pred_root = pred_root
self.skip_first_and_last = skip_first_and_last
def __call__(self, vid_name: str) -> Tuple[str, Dict[str, float], Dict[str, float]]:
"""
vid_name: name of the video to evaluate
"""
# scan the folder to find subfolders for evaluation and
# check if the folder structure is SA-V
to_evaluate, is_sav_format = self.scan_vid_folder(vid_name)
# evaluate each (gt_path, pred_path) pair
eval_results = []
for all_frames, obj_id, gt_path, pred_path in to_evaluate:
if self.skip_first_and_last:
# skip the first and the last frames
all_frames = all_frames[1:-1]
evaluator = Evaluator(name=vid_name, obj_id=obj_id)
for frame in all_frames:
gt_array, pred_array = self.get_gt_and_pred(
gt_path, pred_path, frame, is_sav_format
)
evaluator.feed_frame(mask=pred_array, gt=gt_array)
iou, boundary_f = evaluator.conclude()
eval_results.append((obj_id, iou, boundary_f))
if is_sav_format:
iou_output, boundary_f_output = self.consolidate(eval_results)
else:
assert len(eval_results) == 1
iou_output = eval_results[0][1]
boundary_f_output = eval_results[0][2]
return vid_name, iou_output, boundary_f_output
def get_gt_and_pred(
self,
gt_path: str,
pred_path: str,
f_name: str,
is_sav_format: bool,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Get the ground-truth and predicted masks for a single frame.
"""
gt_mask_path = path.join(gt_path, f_name)
pred_mask_path = path.join(pred_path, f_name)
assert os.path.exists(pred_mask_path), f"{pred_mask_path} not found"
gt_array = np.array(Image.open(gt_mask_path))
pred_array = np.array(Image.open(pred_mask_path))
assert (
gt_array.shape[-2:] == pred_array.shape[-2:]
), f"shape mismatch: {gt_mask_path}, {pred_mask_path}"
if is_sav_format:
assert len(np.unique(gt_array)) <= 2, (
f"found more than 1 object in {gt_mask_path} "
"SA-V format assumes one object mask per png file."
)
assert len(np.unique(pred_array)) <= 2, (
f"found more than 1 object in {pred_mask_path} "
"SA-V format assumes one object mask per png file."
)
gt_array = gt_array > 0
pred_array = pred_array > 0
return gt_array, pred_array
def scan_vid_folder(self, vid_name) -> Tuple[List, bool]:
"""
Scan the folder structure of the video and return a list of folders for evaluate.
"""
vid_gt_path = path.join(self.gt_root, vid_name)
vid_pred_path = path.join(self.pred_root, vid_name)
all_files_and_dirs = sorted(os.listdir(vid_gt_path))
to_evaluate = []
if all(name.endswith(".png") for name in all_files_and_dirs):
# All files are png files, dataset structure similar to DAVIS
is_sav_format = False
frames = all_files_and_dirs
obj_dir = None
to_evaluate.append((frames, obj_dir, vid_gt_path, vid_pred_path))
else:
# SA-V dataset structure, going one layer down into each subdirectory
is_sav_format = True
for obj_dir in all_files_and_dirs:
obj_gt_path = path.join(vid_gt_path, obj_dir)
obj_pred_path = path.join(vid_pred_path, obj_dir)
frames = sorted(os.listdir(obj_gt_path))
to_evaluate.append((frames, obj_dir, obj_gt_path, obj_pred_path))
return to_evaluate, is_sav_format
def consolidate(
self, eval_results
) -> Tuple[str, Dict[str, float], Dict[str, float]]:
"""
Consolidate the results of all the objects from the video into one dictionary.
"""
iou_output = {}
boundary_f_output = {}
for obj_id, iou, boundary_f in eval_results:
assert len(iou) == 1
key = list(iou.keys())[0]
iou_output[obj_id] = iou[key]
boundary_f_output[obj_id] = boundary_f[key]
return iou_output, boundary_f_output
#################################################################################################################
# Functions below are from https://github.com/hkchengrex/vos-benchmark with minor modifications
# _seg2bmap from https://github.com/hkchengrex/vos-benchmark/blob/main/vos_benchmark/utils.py
# get_iou and Evaluator from https://github.com/hkchengrex/vos-benchmark/blob/main/vos_benchmark/evaluator.py
# benchmark from https://github.com/hkchengrex/vos-benchmark/blob/main/vos_benchmark/benchmark.py with slight mod
#################################################################################################################
def _seg2bmap(seg, width=None, height=None):
"""
From a segmentation, compute a binary boundary map with 1 pixel wide
boundaries. The boundary pixels are offset by 1/2 pixel towards the
origin from the actual segment boundary.
Arguments:
seg : Segments labeled from 1..k.
width : Width of desired bmap <= seg.shape[1]
height : Height of desired bmap <= seg.shape[0]
Returns:
bmap (ndarray): Binary boundary map.
David Martin <dmartin@eecs.berkeley.edu>
January 2003
"""
seg = seg.astype(bool)
seg[seg > 0] = 1
assert np.atleast_3d(seg).shape[2] == 1
width = seg.shape[1] if width is None else width
height = seg.shape[0] if height is None else height
h, w = seg.shape[:2]
ar1 = float(width) / float(height)
ar2 = float(w) / float(h)
assert not (
width > w | height > h | abs(ar1 - ar2) > 0.01
), "Cannot convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
e = np.zeros_like(seg)
s = np.zeros_like(seg)
se = np.zeros_like(seg)
e[:, :-1] = seg[:, 1:]
s[:-1, :] = seg[1:, :]
se[:-1, :-1] = seg[1:, 1:]
b = seg ^ e | seg ^ s | seg ^ se
b[-1, :] = seg[-1, :] ^ e[-1, :]
b[:, -1] = seg[:, -1] ^ s[:, -1]
b[-1, -1] = 0
if w == width and h == height:
bmap = b
else:
bmap = np.zeros((height, width))
for x in range(w):
for y in range(h):
if b[y, x]:
j = 1 + math.floor((y - 1) + height / h)
i = 1 + math.floor((x - 1) + width / h)
bmap[j, i] = 1
return bmap
def get_iou(intersection, pixel_sum):
# handle edge cases without resorting to epsilon
if intersection == pixel_sum:
# both mask and gt have zero pixels in them
assert intersection == 0
return 1
return intersection / (pixel_sum - intersection)
class Evaluator:
def __init__(self, boundary=0.008, name=None, obj_id=None):
# boundary: used in computing boundary F-score
self.boundary = boundary
self.name = name
self.obj_id = obj_id
self.objects_in_gt = set()
self.objects_in_masks = set()
self.object_iou = defaultdict(list)
self.boundary_f = defaultdict(list)
def feed_frame(self, mask: np.ndarray, gt: np.ndarray):
"""
Compute and accumulate metrics for a single frame (mask/gt pair)
"""
# get all objects in the ground-truth
gt_objects = np.unique(gt)
gt_objects = gt_objects[gt_objects != 0].tolist()
# get all objects in the predicted mask
mask_objects = np.unique(mask)
mask_objects = mask_objects[mask_objects != 0].tolist()
self.objects_in_gt.update(set(gt_objects))
self.objects_in_masks.update(set(mask_objects))
all_objects = self.objects_in_gt.union(self.objects_in_masks)
# boundary disk for boundary F-score. It is the same for all objects.
bound_pix = np.ceil(self.boundary * np.linalg.norm(mask.shape))
boundary_disk = disk(bound_pix)
for obj_idx in all_objects:
obj_mask = mask == obj_idx
obj_gt = gt == obj_idx
# object iou
self.object_iou[obj_idx].append(
get_iou((obj_mask * obj_gt).sum(), obj_mask.sum() + obj_gt.sum())
)
"""
# boundary f-score
This part is copied from davis2017-evaluation
"""
mask_boundary = _seg2bmap(obj_mask)
gt_boundary = _seg2bmap(obj_gt)
mask_dilated = cv2.dilate(mask_boundary.astype(np.uint8), boundary_disk)
gt_dilated = cv2.dilate(gt_boundary.astype(np.uint8), boundary_disk)
# Get the intersection
gt_match = gt_boundary * mask_dilated
fg_match = mask_boundary * gt_dilated
# Area of the intersection
n_fg = np.sum(mask_boundary)
n_gt = np.sum(gt_boundary)
# Compute precision and recall
if n_fg == 0 and n_gt > 0:
precision = 1
recall = 0
elif n_fg > 0 and n_gt == 0:
precision = 0
recall = 1
elif n_fg == 0 and n_gt == 0:
precision = 1
recall = 1
else:
precision = np.sum(fg_match) / float(n_fg)
recall = np.sum(gt_match) / float(n_gt)
# Compute F measure
if precision + recall == 0:
F = 0
else:
F = 2 * precision * recall / (precision + recall)
self.boundary_f[obj_idx].append(F)
def conclude(self):
all_iou = {}
all_boundary_f = {}
for object_id in self.objects_in_gt:
all_iou[object_id] = np.mean(self.object_iou[object_id]) * 100
all_boundary_f[object_id] = np.mean(self.boundary_f[object_id]) * 100
return all_iou, all_boundary_f
def benchmark(
gt_roots,
mask_roots,
strict=True,
num_processes=None,
*,
verbose=True,
skip_first_and_last=True,
):
"""
gt_roots: a list of paths to datasets, i.e., [path_to_DatasetA, path_to_DatasetB, ...]
mask_roots: same as above, but the .png are masks predicted by the model
strict: when True, all videos in the dataset must have corresponding predictions.
Setting it to False is useful in cases where the ground-truth contains both train/val
sets, but the model only predicts the val subset.
Either way, if a video is predicted (i.e., the corresponding folder exists),
then it must at least contain all the masks in the ground truth annotations.
Masks that are in the prediction but not in the ground-truth
(i.e., sparse annotations) are ignored.
skip_first_and_last: whether we should skip the first and the last frame in evaluation.
This is used by DAVIS 2017 in their semi-supervised evaluation.
It should be disabled for unsupervised evaluation.
"""
assert len(gt_roots) == len(mask_roots)
single_dataset = len(gt_roots) == 1
if verbose:
if skip_first_and_last:
print(
"We are *SKIPPING* the evaluation of the first and the last frame (standard for semi-supervised video object segmentation)."
)
else:
print(
"We are *NOT SKIPPING* the evaluation of the first and the last frame (*NOT STANDARD* for semi-supervised video object segmentation)."
)
pool = Pool(num_processes)
start = time.time()
to_wait = []
for gt_root, mask_root in zip(gt_roots, mask_roots):
# Validate folders
validated = True
gt_videos = os.listdir(gt_root)
mask_videos = os.listdir(mask_root)
# if the user passed the root directory instead of Annotations
if len(gt_videos) != len(mask_videos):
if "Annotations" in gt_videos:
if ".png" not in os.listdir(path.join(gt_root, "Annotations"))[0]:
gt_root = path.join(gt_root, "Annotations")
gt_videos = os.listdir(gt_root)
# remove non-folder items
gt_videos = list(filter(lambda x: path.isdir(path.join(gt_root, x)), gt_videos))
mask_videos = list(
filter(lambda x: path.isdir(path.join(mask_root, x)), mask_videos)
)
if not strict:
videos = sorted(list(set(gt_videos) & set(mask_videos)))
else:
gt_extras = set(gt_videos) - set(mask_videos)
mask_extras = set(mask_videos) - set(gt_videos)
if len(gt_extras) > 0:
print(
f"Videos that are in {gt_root} but not in {mask_root}: {gt_extras}"
)
validated = False
if len(mask_extras) > 0:
print(
f"Videos that are in {mask_root} but not in {gt_root}: {mask_extras}"
)
validated = False
if not validated:
print("Validation failed. Exiting.")
exit(1)
videos = sorted(gt_videos)
if verbose:
print(
f"In dataset {gt_root}, we are evaluating on {len(videos)} videos: {videos}"
)
if single_dataset:
if verbose:
results = tqdm.tqdm(
pool.imap(
VideoEvaluator(
gt_root, mask_root, skip_first_and_last=skip_first_and_last
),
videos,
),
total=len(videos),
)
else:
results = pool.map(
VideoEvaluator(
gt_root, mask_root, skip_first_and_last=skip_first_and_last
),
videos,
)
else:
to_wait.append(
pool.map_async(
VideoEvaluator(
gt_root, mask_root, skip_first_and_last=skip_first_and_last
),
videos,
)
)
pool.close()
all_global_jf, all_global_j, all_global_f = [], [], []
all_object_metrics = []
for i, mask_root in enumerate(mask_roots):
if not single_dataset:
results = to_wait[i].get()
all_iou = []
all_boundary_f = []
object_metrics = {}
for name, iou, boundary_f in results:
all_iou.extend(list(iou.values()))
all_boundary_f.extend(list(boundary_f.values()))
object_metrics[name] = (iou, boundary_f)
global_j = np.array(all_iou).mean()
global_f = np.array(all_boundary_f).mean()
global_jf = (global_j + global_f) / 2
time_taken = time.time() - start
"""
Build string for reporting results
"""
# find max length for padding
ml = max(*[len(n) for n in object_metrics.keys()], len("Global score"))
# build header
out_string = f'{"sequence":<{ml}},{"obj":>3}, {"J&F":>4}, {"J":>4}, {"F":>4}\n'
out_string += f'{"Global score":<{ml}},{"":>3}, {global_jf:.1f}, {global_j:.1f}, {global_f:.1f}\n'
# append one line for each object
for name, (iou, boundary_f) in object_metrics.items():
for object_idx in iou.keys():
j, f = iou[object_idx], boundary_f[object_idx]
jf = (j + f) / 2
out_string += (
f"{name:<{ml}},{object_idx:03}, {jf:>4.1f}, {j:>4.1f}, {f:>4.1f}\n"
)
# print to console
if verbose:
print(out_string.replace(",", " "), end="")
print("\nSummary:")
print(
f"Global score: J&F: {global_jf:.1f} J: {global_j:.1f} F: {global_f:.1f}"
)
print(f"Time taken: {time_taken:.2f}s")
# print to file
result_path = path.join(mask_root, "results.csv")
print(f"Saving the results to {result_path}")
with open(result_path, "w") as f:
f.write(out_string)
all_global_jf.append(global_jf)
all_global_j.append(global_j)
all_global_f.append(global_f)
all_object_metrics.append(object_metrics)
return all_global_jf, all_global_j, all_global_f, all_object_metrics
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the sav_dataset directory of this source tree.
import json
import os
from typing import Dict, List, Optional, Tuple
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_util
def decode_video(video_path: str) -> List[np.ndarray]:
"""
Decode the video and return the RGB frames
"""
video = cv2.VideoCapture(video_path)
video_frames = []
while video.isOpened():
ret, frame = video.read()
if ret:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
video_frames.append(frame)
else:
break
return video_frames
def show_anns(masks, colors: List, borders=True) -> None:
"""
show the annotations
"""
# return if no masks
if len(masks) == 0:
return
# sort masks by size
sorted_annot_and_color = sorted(
zip(masks, colors), key=(lambda x: x[0].sum()), reverse=True
)
H, W = sorted_annot_and_color[0][0].shape[0], sorted_annot_and_color[0][0].shape[1]
canvas = np.ones((H, W, 4))
canvas[:, :, 3] = 0 # set the alpha channel
contour_thickness = max(1, int(min(5, 0.01 * min(H, W))))
for mask, color in sorted_annot_and_color:
canvas[mask] = np.concatenate([color, [0.55]])
if borders:
contours, _ = cv2.findContours(
np.array(mask, dtype=np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
)
cv2.drawContours(
canvas, contours, -1, (0.05, 0.05, 0.05, 1), thickness=contour_thickness
)
ax = plt.gca()
ax.imshow(canvas)
class SAVDataset:
"""
SAVDataset is a class to load the SAV dataset and visualize the annotations.
"""
def __init__(self, sav_dir, annot_sample_rate=4):
"""
Args:
sav_dir: the directory of the SAV dataset
annot_sample_rate: the sampling rate of the annotations.
The annotations are aligned with the videos at 6 fps.
"""
self.sav_dir = sav_dir
self.annot_sample_rate = annot_sample_rate
self.manual_mask_colors = np.random.random((256, 3))
self.auto_mask_colors = np.random.random((256, 3))
def read_frames(self, mp4_path: str) -> None:
"""
Read the frames and downsample them to align with the annotations.
"""
if not os.path.exists(mp4_path):
print(f"{mp4_path} doesn't exist.")
return None
else:
# decode the video
frames = decode_video(mp4_path)
print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).")
# downsample the frames to align with the annotations
frames = frames[:: self.annot_sample_rate]
print(
f"Videos are annotated every {self.annot_sample_rate} frames. "
"To align with the annotations, "
f"downsample the video to {len(frames)} frames."
)
return frames
def get_frames_and_annotations(
self, video_id: str
) -> Tuple[List | None, Dict | None, Dict | None]:
"""
Get the frames and annotations for video.
"""
# load the video
mp4_path = os.path.join(self.sav_dir, video_id + ".mp4")
frames = self.read_frames(mp4_path)
if frames is None:
return None, None, None
# load the manual annotations
manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json")
if not os.path.exists(manual_annot_path):
print(f"{manual_annot_path} doesn't exist. Something might be wrong.")
manual_annot = None
else:
manual_annot = json.load(open(manual_annot_path))
# load the manual annotations
auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json")
if not os.path.exists(auto_annot_path):
print(f"{auto_annot_path} doesn't exist.")
auto_annot = None
else:
auto_annot = json.load(open(auto_annot_path))
return frames, manual_annot, auto_annot
def visualize_annotation(
self,
frames: List[np.ndarray],
auto_annot: Optional[Dict],
manual_annot: Optional[Dict],
annotated_frame_id: int,
show_auto=True,
show_manual=True,
) -> None:
"""
Visualize the annotations on the annotated_frame_id.
If show_manual is True, show the manual annotations.
If show_auto is True, show the auto annotations.
By default, show both auto and manual annotations.
"""
if annotated_frame_id >= len(frames):
print("invalid annotated_frame_id")
return
rles = []
colors = []
if show_manual and manual_annot is not None:
rles.extend(manual_annot["masklet"][annotated_frame_id])
colors.extend(
self.manual_mask_colors[
: len(manual_annot["masklet"][annotated_frame_id])
]
)
if show_auto and auto_annot is not None:
rles.extend(auto_annot["masklet"][annotated_frame_id])
colors.extend(
self.auto_mask_colors[: len(auto_annot["masklet"][annotated_frame_id])]
)
plt.imshow(frames[annotated_frame_id])
if len(rles) > 0:
masks = [mask_util.decode(rle) > 0 for rle in rles]
show_anns(masks, colors)
else:
print("No annotation will be shown")
plt.axis("off")
plt.show()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from setuptools import find_packages, setup
# Package metadata
NAME = "SAM-2"
VERSION = "1.0"
DESCRIPTION = "SAM 2: Segment Anything in Images and Videos"
URL = "https://github.com/facebookresearch/sam2"
AUTHOR = "Meta AI"
AUTHOR_EMAIL = "segment-anything@meta.com"
LICENSE = "Apache 2.0"
# Read the contents of README file
with open("README.md", "r", encoding="utf-8") as f:
LONG_DESCRIPTION = f.read()
# Required dependencies
REQUIRED_PACKAGES = [
# "torch>=2.5.1",
# "torchvision>=0.20.1",
"numpy>=1.24.4",
"tqdm>=4.66.1",
"hydra-core>=1.3.2",
"iopath>=0.1.10",
"pillow>=9.4.0",
]
EXTRA_PACKAGES = {
"notebooks": [
"matplotlib>=3.9.1",
"jupyter>=1.0.0",
"opencv-python>=4.7.0",
"eva-decord>=0.6.1",
],
"interactive-demo": [
"Flask>=3.0.3",
"Flask-Cors>=5.0.0",
"av>=13.0.0",
"dataclasses-json>=0.6.7",
"eva-decord>=0.6.1",
"gunicorn>=23.0.0",
"imagesize>=1.4.1",
"pycocotools>=2.0.8",
"strawberry-graphql>=0.243.0",
],
"dev": [
"black==24.2.0",
"usort==1.0.2",
"ufmt==2.0.0b2",
"fvcore>=0.1.5.post20221221",
"pandas>=2.2.2",
"scikit-image>=0.24.0",
"tensorboard>=2.17.0",
"pycocotools>=2.0.8",
"tensordict>=0.6.0",
"opencv-python>=4.7.0",
"submitit>=1.5.1",
],
}
# By default, we also build the SAM 2 CUDA extension.
# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
# By default, we allow SAM 2 installation to proceed even with build errors.
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
# Catch and skip errors during extension building and print a warning message
# (note that this message only shows up under verbose build mode
# "pip install -v -e ." or "python setup.py build_ext -v")
CUDA_ERROR_MSG = (
"{}\n\n"
"Failed to build the SAM 2 CUDA extension due to the error above. "
"You can still use SAM 2 and it's OK to ignore the error above, although some "
"post-processing functionality may be limited (which doesn't affect the results in most cases; "
"(see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n"
)
def get_extensions():
if not BUILD_CUDA:
return []
try:
from torch.utils.cpp_extension import CUDAExtension
srcs = ["sam2/csrc/connected_components.cu"]
compile_args = {
"cxx": [],
"nvcc": [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
}
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
except Exception as e:
if BUILD_ALLOW_ERRORS:
print(CUDA_ERROR_MSG.format(e))
ext_modules = []
else:
raise e
return ext_modules
try:
from torch.utils.cpp_extension import BuildExtension
class BuildExtensionIgnoreErrors(BuildExtension):
def finalize_options(self):
try:
super().finalize_options()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
def build_extensions(self):
try:
super().build_extensions()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
def get_ext_filename(self, ext_name):
try:
return super().get_ext_filename(ext_name)
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
return "_C.so"
cmdclass = {
"build_ext": (
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
if BUILD_ALLOW_ERRORS
else BuildExtension.with_options(no_python_abi_suffix=True)
)
}
except Exception as e:
cmdclass = {}
if BUILD_ALLOW_ERRORS:
print(CUDA_ERROR_MSG.format(e))
else:
raise e
# Setup configuration
setup(
name=NAME,
version=VERSION,
description=DESCRIPTION,
long_description=LONG_DESCRIPTION,
long_description_content_type="text/markdown",
url=URL,
author=AUTHOR,
author_email=AUTHOR_EMAIL,
license=LICENSE,
packages=find_packages(exclude="notebooks"),
include_package_data=True,
install_requires=REQUIRED_PACKAGES,
extras_require=EXTRA_PACKAGES,
python_requires=">=3.10.0",
ext_modules=get_extensions(),
cmdclass=cmdclass,
)
## SAM 2 toolkits
This directory provides toolkits for additional SAM 2 use cases.
### Semi-supervised VOS inference
The `vos_inference.py` script can be used to generate predictions for semi-supervised video object segmentation (VOS) evaluation on datasets such as [DAVIS](https://davischallenge.org/index.html), [MOSE](https://henghuiding.github.io/MOSE/) or the SA-V dataset.
After installing SAM 2 and its dependencies, it can be used as follows ([DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). This script saves the prediction PNG files to the `--output_mask_dir`.
```bash
python ./tools/vos_inference.py \
--sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \
--sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \
--base_video_dir /path-to-davis-2017/JPEGImages/480p \
--input_mask_dir /path-to-davis-2017/Annotations/480p \
--video_list_file /path-to-davis-2017/ImageSets/2017/val.txt \
--output_mask_dir ./outputs/davis_2017_pred_pngs
```
(replace `/path-to-davis-2017` with the path to DAVIS 2017 dataset)
To evaluate on the SA-V dataset with per-object PNG files for the object masks, we need to **add the `--per_obj_png_file` flag** as follows (using SA-V val as an example). This script will also save per-object PNG files for the output masks under the `--per_obj_png_file` flag.
```bash
python ./tools/vos_inference.py \
--sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \
--sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \
--base_video_dir /path-to-sav-val/JPEGImages_24fps \
--input_mask_dir /path-to-sav-val/Annotations_6fps \
--video_list_file /path-to-sav-val/sav_val.txt \
--per_obj_png_file \
--output_mask_dir ./outputs/sav_val_pred_pngs
```
(replace `/path-to-sav-val` with the path to SA-V val)
Then, we can use the evaluation tools or servers for each dataset to get the performance of the prediction PNG files above.
Note: by default, the `vos_inference.py` script above assumes that all objects to track already appear on frame 0 in each video (as is the case in DAVIS, MOSE or SA-V). **For VOS datasets that don't have all objects to track appearing in the first frame (such as LVOS or YouTube-VOS), please add the `--track_object_appearing_later_in_video` flag when using `vos_inference.py`**.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
from collections import defaultdict
import numpy as np
import torch
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor
# the PNG palette for DAVIS 2017 dataset
DAVIS_PALETTE = b"\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0"
def load_ann_png(path):
"""Load a PNG file as a mask and its palette."""
mask = Image.open(path)
palette = mask.getpalette()
mask = np.array(mask).astype(np.uint8)
return mask, palette
def save_ann_png(path, mask, palette):
"""Save a mask as a PNG file with the given palette."""
assert mask.dtype == np.uint8
assert mask.ndim == 2
output_mask = Image.fromarray(mask)
output_mask.putpalette(palette)
output_mask.save(path)
def get_per_obj_mask(mask):
"""Split a mask into per-object masks."""
object_ids = np.unique(mask)
object_ids = object_ids[object_ids > 0].tolist()
per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids}
return per_obj_mask
def put_per_obj_mask(per_obj_mask, height, width):
"""Combine per-object masks into a single mask."""
mask = np.zeros((height, width), dtype=np.uint8)
object_ids = sorted(per_obj_mask)[::-1]
for object_id in object_ids:
object_mask = per_obj_mask[object_id]
object_mask = object_mask.reshape(height, width)
mask[object_mask] = object_id
return mask
def load_masks_from_dir(
input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False
):
"""Load masks from a directory as a dict of per-object masks."""
if not per_obj_png_file:
input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png")
if allow_missing and not os.path.exists(input_mask_path):
return {}, None
input_mask, input_palette = load_ann_png(input_mask_path)
per_obj_input_mask = get_per_obj_mask(input_mask)
else:
per_obj_input_mask = {}
input_palette = None
# each object is a directory in "{object_id:%03d}" format
for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
object_id = int(object_name)
input_mask_path = os.path.join(
input_mask_dir, video_name, object_name, f"{frame_name}.png"
)
if allow_missing and not os.path.exists(input_mask_path):
continue
input_mask, input_palette = load_ann_png(input_mask_path)
per_obj_input_mask[object_id] = input_mask > 0
return per_obj_input_mask, input_palette
def save_masks_to_dir(
output_mask_dir,
video_name,
frame_name,
per_obj_output_mask,
height,
width,
per_obj_png_file,
output_palette,
):
"""Save masks to a directory as PNG files."""
os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
if not per_obj_png_file:
output_mask = put_per_obj_mask(per_obj_output_mask, height, width)
output_mask_path = os.path.join(
output_mask_dir, video_name, f"{frame_name}.png"
)
save_ann_png(output_mask_path, output_mask, output_palette)
else:
for object_id, object_mask in per_obj_output_mask.items():
object_name = f"{object_id:03d}"
os.makedirs(
os.path.join(output_mask_dir, video_name, object_name),
exist_ok=True,
)
output_mask = object_mask.reshape(height, width).astype(np.uint8)
output_mask_path = os.path.join(
output_mask_dir, video_name, object_name, f"{frame_name}.png"
)
save_ann_png(output_mask_path, output_mask, output_palette)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def vos_inference(
predictor,
base_video_dir,
input_mask_dir,
output_mask_dir,
video_name,
score_thresh=0.0,
use_all_masks=False,
per_obj_png_file=False,
):
"""Run VOS inference on a single video with the given predictor."""
# load the video frames and initialize the inference state on this video
video_dir = os.path.join(base_video_dir, video_name)
frame_names = [
os.path.splitext(p)[0]
for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(
video_path=video_dir, async_loading_frames=False
)
height = inference_state["video_height"]
width = inference_state["video_width"]
input_palette = None
# fetch mask inputs from input_mask_dir (either only mask for the first frame, or all available masks)
if not use_all_masks:
# use only the first video's ground-truth mask as the input mask
input_frame_inds = [0]
else:
# use all mask files available in the input_mask_dir as the input masks
if not per_obj_png_file:
input_frame_inds = [
idx
for idx, name in enumerate(frame_names)
if os.path.exists(
os.path.join(input_mask_dir, video_name, f"{name}.png")
)
]
else:
input_frame_inds = [
idx
for object_name in os.listdir(os.path.join(input_mask_dir, video_name))
for idx, name in enumerate(frame_names)
if os.path.exists(
os.path.join(input_mask_dir, video_name, object_name, f"{name}.png")
)
]
# check and make sure we got at least one input frame
if len(input_frame_inds) == 0:
raise RuntimeError(
f"In {video_name=}, got no input masks in {input_mask_dir=}. "
"Please make sure the input masks are available in the correct format."
)
input_frame_inds = sorted(set(input_frame_inds))
# add those input masks to SAM 2 inference state before propagation
object_ids_set = None
for input_frame_idx in input_frame_inds:
try:
per_obj_input_mask, input_palette = load_masks_from_dir(
input_mask_dir=input_mask_dir,
video_name=video_name,
frame_name=frame_names[input_frame_idx],
per_obj_png_file=per_obj_png_file,
)
except FileNotFoundError as e:
raise RuntimeError(
f"In {video_name=}, failed to load input mask for frame {input_frame_idx=}. "
"Please add the `--track_object_appearing_later_in_video` flag "
"for VOS datasets that don't have all objects to track appearing "
"in the first frame (such as LVOS or YouTube-VOS)."
) from e
# get the list of object ids to track from the first input frame
if object_ids_set is None:
object_ids_set = set(per_obj_input_mask)
for object_id, object_mask in per_obj_input_mask.items():
# check and make sure no new object ids appear only in later frames
if object_id not in object_ids_set:
raise RuntimeError(
f"In {video_name=}, got a new {object_id=} appearing only in a "
f"later {input_frame_idx=} (but not appearing in the first frame). "
"Please add the `--track_object_appearing_later_in_video` flag "
"for VOS datasets that don't have all objects to track appearing "
"in the first frame (such as LVOS or YouTube-VOS)."
)
predictor.add_new_mask(
inference_state=inference_state,
frame_idx=input_frame_idx,
obj_id=object_id,
mask=object_mask,
)
# check and make sure we have at least one object to track
if object_ids_set is None or len(object_ids_set) == 0:
raise RuntimeError(
f"In {video_name=}, got no object ids on {input_frame_inds=}. "
"Please add the `--track_object_appearing_later_in_video` flag "
"for VOS datasets that don't have all objects to track appearing "
"in the first frame (such as LVOS or YouTube-VOS)."
)
# run propagation throughout the video and collect the results in a dict
os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
output_palette = input_palette or DAVIS_PALETTE
video_segments = {} # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
inference_state
):
per_obj_output_mask = {
out_obj_id: (out_mask_logits[i] > score_thresh).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
video_segments[out_frame_idx] = per_obj_output_mask
# write the output masks as palette PNG files to output_mask_dir
for out_frame_idx, per_obj_output_mask in video_segments.items():
save_masks_to_dir(
output_mask_dir=output_mask_dir,
video_name=video_name,
frame_name=frame_names[out_frame_idx],
per_obj_output_mask=per_obj_output_mask,
height=height,
width=width,
per_obj_png_file=per_obj_png_file,
output_palette=output_palette,
)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def vos_separate_inference_per_object(
predictor,
base_video_dir,
input_mask_dir,
output_mask_dir,
video_name,
score_thresh=0.0,
use_all_masks=False,
per_obj_png_file=False,
):
"""
Run VOS inference on a single video with the given predictor.
Unlike `vos_inference`, this function run inference separately for each object
in a video, which could be applied to datasets like LVOS or YouTube-VOS that
don't have all objects to track appearing in the first frame (i.e. some objects
might appear only later in the video).
"""
# load the video frames and initialize the inference state on this video
video_dir = os.path.join(base_video_dir, video_name)
frame_names = [
os.path.splitext(p)[0]
for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(
video_path=video_dir, async_loading_frames=False
)
height = inference_state["video_height"]
width = inference_state["video_width"]
input_palette = None
# collect all the object ids and their input masks
inputs_per_object = defaultdict(dict)
for idx, name in enumerate(frame_names):
if per_obj_png_file or os.path.exists(
os.path.join(input_mask_dir, video_name, f"{name}.png")
):
per_obj_input_mask, input_palette = load_masks_from_dir(
input_mask_dir=input_mask_dir,
video_name=video_name,
frame_name=frame_names[idx],
per_obj_png_file=per_obj_png_file,
allow_missing=True,
)
for object_id, object_mask in per_obj_input_mask.items():
# skip empty masks
if not np.any(object_mask):
continue
# if `use_all_masks=False`, we only use the first mask for each object
if len(inputs_per_object[object_id]) > 0 and not use_all_masks:
continue
print(f"adding mask from frame {idx} as input for {object_id=}")
inputs_per_object[object_id][idx] = object_mask
# run inference separately for each object in the video
object_ids = sorted(inputs_per_object)
output_scores_per_object = defaultdict(dict)
for object_id in object_ids:
# add those input masks to SAM 2 inference state before propagation
input_frame_inds = sorted(inputs_per_object[object_id])
predictor.reset_state(inference_state)
for input_frame_idx in input_frame_inds:
predictor.add_new_mask(
inference_state=inference_state,
frame_idx=input_frame_idx,
obj_id=object_id,
mask=inputs_per_object[object_id][input_frame_idx],
)
# run propagation throughout the video and collect the results in a dict
for out_frame_idx, _, out_mask_logits in predictor.propagate_in_video(
inference_state,
start_frame_idx=min(input_frame_inds),
reverse=False,
):
obj_scores = out_mask_logits.cpu().numpy()
output_scores_per_object[object_id][out_frame_idx] = obj_scores
# post-processing: consolidate the per-object scores into per-frame masks
os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
output_palette = input_palette or DAVIS_PALETTE
video_segments = {} # video_segments contains the per-frame segmentation results
for frame_idx in range(len(frame_names)):
scores = torch.full(
size=(len(object_ids), 1, height, width),
fill_value=-1024.0,
dtype=torch.float32,
)
for i, object_id in enumerate(object_ids):
if frame_idx in output_scores_per_object[object_id]:
scores[i] = torch.from_numpy(
output_scores_per_object[object_id][frame_idx]
)
if not per_obj_png_file:
scores = predictor._apply_non_overlapping_constraints(scores)
per_obj_output_mask = {
object_id: (scores[i] > score_thresh).cpu().numpy()
for i, object_id in enumerate(object_ids)
}
video_segments[frame_idx] = per_obj_output_mask
# write the output masks as palette PNG files to output_mask_dir
for frame_idx, per_obj_output_mask in video_segments.items():
save_masks_to_dir(
output_mask_dir=output_mask_dir,
video_name=video_name,
frame_name=frame_names[frame_idx],
per_obj_output_mask=per_obj_output_mask,
height=height,
width=width,
per_obj_png_file=per_obj_png_file,
output_palette=output_palette,
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--sam2_cfg",
type=str,
default="configs/sam2.1/sam2.1_hiera_b+.yaml",
help="SAM 2 model configuration file",
)
parser.add_argument(
"--sam2_checkpoint",
type=str,
default="./checkpoints/sam2.1_hiera_base_plus.pt",
help="path to the SAM 2 model checkpoint",
)
parser.add_argument(
"--base_video_dir",
type=str,
required=True,
help="directory containing videos (as JPEG files) to run VOS prediction on",
)
parser.add_argument(
"--input_mask_dir",
type=str,
required=True,
help="directory containing input masks (as PNG files) of each video",
)
parser.add_argument(
"--video_list_file",
type=str,
default=None,
help="text file containing the list of video names to run VOS prediction on",
)
parser.add_argument(
"--output_mask_dir",
type=str,
required=True,
help="directory to save the output masks (as PNG files)",
)
parser.add_argument(
"--score_thresh",
type=float,
default=0.0,
help="threshold for the output mask logits (default: 0.0)",
)
parser.add_argument(
"--use_all_masks",
action="store_true",
help="whether to use all available PNG files in input_mask_dir "
"(default without this flag: just the first PNG file as input to the SAM 2 model; "
"usually we don't need this flag, since semi-supervised VOS evaluation usually takes input from the first frame only)",
)
parser.add_argument(
"--per_obj_png_file",
action="store_true",
help="whether use separate per-object PNG files for input and output masks "
"(default without this flag: all object masks are packed into a single PNG file on each frame following DAVIS format; "
"note that the SA-V dataset stores each object mask as an individual PNG file and requires this flag)",
)
parser.add_argument(
"--apply_postprocessing",
action="store_true",
help="whether to apply postprocessing (e.g. hole-filling) to the output masks "
"(we don't apply such post-processing in the SAM 2 model evaluation)",
)
parser.add_argument(
"--track_object_appearing_later_in_video",
action="store_true",
help="whether to track objects that appear later in the video (i.e. not on the first frame; "
"some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
)
parser.add_argument(
"--use_vos_optimized_video_predictor",
action="store_true",
help="whether to use vos optimized video predictor with all modules compiled",
)
args = parser.parse_args()
# if we use per-object PNG files, they could possibly overlap in inputs and outputs
hydra_overrides_extra = [
"++model.non_overlap_masks=" + ("false" if args.per_obj_png_file else "true")
]
predictor = build_sam2_video_predictor(
config_file=args.sam2_cfg,
ckpt_path=args.sam2_checkpoint,
apply_postprocessing=args.apply_postprocessing,
hydra_overrides_extra=hydra_overrides_extra,
vos_optimized=args.use_vos_optimized_video_predictor,
)
if args.use_all_masks:
print("using all available masks in input_mask_dir as input to the SAM 2 model")
else:
print(
"using only the first frame's mask in input_mask_dir as input to the SAM 2 model"
)
# if a video list file is provided, read the video names from the file
# (otherwise, we use all subdirectories in base_video_dir)
if args.video_list_file is not None:
with open(args.video_list_file, "r") as f:
video_names = [v.strip() for v in f.readlines()]
else:
video_names = [
p
for p in os.listdir(args.base_video_dir)
if os.path.isdir(os.path.join(args.base_video_dir, p))
]
print(f"running VOS prediction on {len(video_names)} videos:\n{video_names}")
for n_video, video_name in enumerate(video_names):
print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}")
if not args.track_object_appearing_later_in_video:
vos_inference(
predictor=predictor,
base_video_dir=args.base_video_dir,
input_mask_dir=args.input_mask_dir,
output_mask_dir=args.output_mask_dir,
video_name=video_name,
score_thresh=args.score_thresh,
use_all_masks=args.use_all_masks,
per_obj_png_file=args.per_obj_png_file,
)
else:
vos_separate_inference_per_object(
predictor=predictor,
base_video_dir=args.base_video_dir,
input_mask_dir=args.input_mask_dir,
output_mask_dir=args.output_mask_dir,
video_name=video_name,
score_thresh=args.score_thresh,
use_all_masks=args.use_all_masks,
per_obj_png_file=args.per_obj_png_file,
)
print(
f"completed VOS prediction on {len(video_names)} videos -- "
f"output masks saved to {args.output_mask_dir}"
)
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment