from collections import OrderedDict
import torch
from tqdm import tqdm
from projects.llava_sam2.models.extension import SAM2Base
from third_parts.sam2.modeling.sam2_base import NO_OBJ_SCORE
from third_parts.sam2.utils.misc import fill_holes_in_mask_scores
def _obj_id_to_idx(inference_state, obj_id):
"""Map client-side object id to model-side object index."""
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
if obj_idx is not None:
return obj_idx
# This is a new object id not sent to the server before. We only allow adding
# new objects *before* the tracking starts.
allow_new_object = not inference_state["tracking_has_started"]
if allow_new_object:
# get the next object slot
obj_idx = len(inference_state["obj_id_to_idx"])
inference_state["obj_id_to_idx"][obj_id] = obj_idx
inference_state["obj_idx_to_id"][obj_idx] = obj_id
inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
# set up input and output structures for this object
inference_state["point_inputs_per_obj"][obj_idx] = {}
inference_state["mask_inputs_per_obj"][obj_idx] = {}
inference_state["output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
inference_state["temp_output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
return obj_idx
else:
raise RuntimeError(
f"Cannot add new object id {obj_id} after tracking starts. "
f"All existing object ids: {inference_state['obj_ids']}. "
f"Please call 'reset_state' to restart from scratch."
)
def _get_maskmem_pos_enc(inference_state, current_out):
"""
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
a constant in the inference session to reduce session storage size.
"""
model_constants = inference_state["constants"]
# "out_maskmem_pos_enc" should be either a list of tensors or None
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
if out_maskmem_pos_enc is not None:
if "maskmem_pos_enc" not in model_constants:
assert isinstance(out_maskmem_pos_enc, list)
# only take the slice for one object, since it's the same across objects
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
else:
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
# expand the cached maskmem_pos_enc to the actual batch size
batch_size = out_maskmem_pos_enc[0].size(0)
expanded_maskmem_pos_enc = [
x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
]
else:
expanded_maskmem_pos_enc = None
return expanded_maskmem_pos_enc
def _obj_idx_to_id(inference_state, obj_idx):
"""Map model-side object index to client-side object id."""
return inference_state["obj_idx_to_id"][obj_idx]
def _get_obj_num(inference_state):
"""Get the total number of unique object ids received so far in this session."""
return len(inference_state["obj_idx_to_id"])
class SAM2VideoPredictor(SAM2Base):
"""The predictor class to handle user interactions and manage inference states."""
def __init__(
self,
fill_hole_area=0,
# whether to apply non-overlapping constraints on the output object masks
non_overlap_masks=False,
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
# note that this only applies to *single-object tracking*, unless `clear_non_cond_mem_for_multi_obj` is also set to True
clear_non_cond_mem_around_input=False,
# whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
clear_non_cond_mem_for_multi_obj=False,
**kwargs,
):
super().__init__(**kwargs)
self.fill_hole_area = fill_hole_area
self.non_overlap_masks = non_overlap_masks
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
def _get_image_feature(self, inference_state, frame_idx, batch_size):
"""Compute the image features on a given frame."""
# Look up in the cache first
image, backbone_out = inference_state["cached_features"].get(
frame_idx, (None, None)
)
if backbone_out is None:
# Cache miss -- we will run inference on a single image
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future).
inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
# expand the features to have the same dimension as the number of objects
expanded_image = image.expand(batch_size, -1, -1, -1)
expanded_backbone_out = {
"backbone_fpn": backbone_out["backbone_fpn"].copy(),
"vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
}
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
expanded_backbone_out["backbone_fpn"][i] = feat.expand(
batch_size, -1, -1, -1
)
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
pos = pos.expand(batch_size, -1, -1, -1)
expanded_backbone_out["vision_pos_enc"][i] = pos
features = self._prepare_backbone_features(expanded_backbone_out)
features = (expanded_image,) + features
return features
def _run_single_frame_inference(
self,
inference_state,
output_dict,
frame_idx,
batch_size,
is_init_cond_frame,
point_inputs,
mask_inputs,
reverse,
run_mem_encoder,
prev_sam_mask_logits=None,
## Extension: LLM prompt
language_embd=None,
):
"""Run tracking on a single frame based on current inputs and previous memory."""
# Retrieve correct image features
(
_,
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = self._get_image_feature(inference_state, frame_idx, batch_size)
# point and mask should not appear as input simultaneously on the same frame
assert point_inputs is None or mask_inputs is None
current_out = self.track_step(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats,
current_vision_pos_embeds=current_vision_pos_embeds,
feat_sizes=feat_sizes,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
output_dict=output_dict,
num_frames=inference_state["num_frames"],
track_in_reverse=reverse,
run_mem_encoder=run_mem_encoder,
prev_sam_mask_logits=prev_sam_mask_logits,
language_embd=language_embd,
)
# optionally offload the output to CPU memory to save GPU space
storage_device = inference_state["storage_device"]
maskmem_features = current_out["maskmem_features"]
if maskmem_features is not None:
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
pred_masks_gpu = current_out["pred_masks"]
# potentially fill holes in the predicted masks
if self.fill_hole_area > 0:
pred_masks_gpu = fill_holes_in_mask_scores(
pred_masks_gpu, self.fill_hole_area
)
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
maskmem_pos_enc = _get_maskmem_pos_enc(inference_state, current_out)
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
obj_ptr = current_out["obj_ptr"]
# make a compact version of this frame's output to reduce the state size
compact_current_out = {
"maskmem_features": maskmem_features,
"maskmem_pos_enc": maskmem_pos_enc,
"pred_masks": pred_masks,
"obj_ptr": obj_ptr,
}
return compact_current_out, pred_masks_gpu
def _consolidate_temp_output_across_obj(
self,
inference_state,
frame_idx,
is_cond,
run_mem_encoder,
consolidate_at_video_res=False,
):
"""
Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
a frame into a single output for all objects, including
1) fill any missing objects either from `output_dict_per_obj` (if they exist in
`output_dict_per_obj` for this frame) or leave them as placeholder values
(if they don't exist in `output_dict_per_obj` for this frame);
2) if specified, rerun the memory encoder after applying non-overlapping
constraints on the object scores.
"""
batch_size = _get_obj_num(inference_state)
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Optionally, we allow consolidating the temporary outputs at the original
# video resolution (to provide a better editing experience for mask prompts).
if consolidate_at_video_res:
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
consolidated_H = inference_state["video_height"]
consolidated_W = inference_state["video_width"]
consolidated_mask_key = "pred_masks_video_res"
else:
consolidated_H = consolidated_W = self.image_size // 4
consolidated_mask_key = "pred_masks"
# Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
# will be added when rerunning the memory encoder after applying non-overlapping
# constraints to object scores. Its "pred_masks" are prefilled with a large
# negative value (NO_OBJ_SCORE) to represent missing objects.
consolidated_out = {
"maskmem_features": None,
"maskmem_pos_enc": None,
consolidated_mask_key: torch.full(
size=(batch_size, 1, consolidated_H, consolidated_W),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["storage_device"],
),
"obj_ptr": torch.full(
size=(batch_size, self.hidden_dim),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["device"],
),
}
empty_mask_ptr = None
for obj_idx in range(batch_size):
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
out = obj_temp_output_dict[storage_key].get(frame_idx, None)
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
# we fall back and look up its previous output in "output_dict_per_obj".
# We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
# "output_dict_per_obj" to find a previous output for this object.
if out is None:
out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
if out is None:
out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
# If the object doesn't appear in "output_dict_per_obj" either, we skip it
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
# placeholder above) and set its object pointer to be a dummy pointer.
if out is None:
# Fill in dummy object pointers for those objects without any inputs or
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
# i.e. when we need to build the memory for tracking).
if run_mem_encoder:
if empty_mask_ptr is None:
empty_mask_ptr = self._get_empty_mask_ptr(
inference_state, frame_idx
)
# fill object pointer with a dummy pointer (based on an empty mask)
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
continue
# Add the temporary object output mask to consolidated output mask
obj_mask = out["pred_masks"]
consolidated_pred_masks = consolidated_out[consolidated_mask_key]
if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
else:
# Resize first if temporary object mask has a different resolution
resized_obj_mask = torch.nn.functional.interpolate(
obj_mask,
size=consolidated_pred_masks.shape[-2:],
mode="bilinear",
align_corners=False,
)
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
# Optionally, apply non-overlapping constraints on the consolidated scores
# and rerun the memory encoder
if run_mem_encoder:
device = inference_state["device"]
high_res_masks = torch.nn.functional.interpolate(
consolidated_out["pred_masks"].to(device, non_blocking=True),
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
if self.non_overlap_masks_for_mem_enc:
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
inference_state=inference_state,
frame_idx=frame_idx,
batch_size=batch_size,
high_res_masks=high_res_masks,
is_mask_from_pts=True, # these frames are what the user interacted with
)
consolidated_out["maskmem_features"] = maskmem_features
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
return consolidated_out
def _get_orig_video_res_output(self, inference_state, any_res_masks):
"""
Resize the object scores to the original video resolution (video_res_masks)
and apply non-overlapping constraints for final output.
"""
device = inference_state["device"]
video_H = inference_state["video_height"]
video_W = inference_state["video_width"]
any_res_masks = any_res_masks.to(device, non_blocking=True)
if any_res_masks.shape[-2:] == (video_H, video_W):
video_res_masks = any_res_masks
else:
video_res_masks = torch.nn.functional.interpolate(
any_res_masks,
size=(video_H, video_W),
mode="bilinear",
align_corners=False,
)
if self.non_overlap_masks:
video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
return any_res_masks, video_res_masks
def init_state(
self,
images
):
"""Initialize a inference state."""
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = False
# whether to offload the inference state to CPU memory
# turning on this option saves the GPU memory at the cost of a lower tracking fps
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
# and from 24 to 21 when tracking two objects)
inference_state["offload_state_to_cpu"] = False
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = self.image_size
inference_state["video_width"] = self.image_size
inference_state["device"] = torch.device("cuda")
inference_state["storage_device"] = torch.device("cuda")
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
# visual features on a small number of recently visited frames for quick interactions
inference_state["cached_features"] = {}
# values that don't change across frames (so we only need to hold one copy of them)
inference_state["constants"] = {}
# mapping between client-side object id and model-side object index
inference_state["obj_id_to_idx"] = OrderedDict()
inference_state["obj_idx_to_id"] = OrderedDict()
inference_state["obj_ids"] = []
# A storage to hold the model's tracking results and states on each frame
inference_state["output_dict"] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
# Slices (views) of each object's tracking results, sharing the same memory with "output_dict"
inference_state["output_dict_per_obj"] = {}
# A temporary storage to hold new outputs when the user interacts with a frame
# to add clicks or masks (it's merged into "output_dict" before propagation starts)
inference_state["temp_output_dict_per_obj"] = {}
# Frames that already hold consolidated outputs from click or mask inputs
# (we directly use their consolidated outputs during tracking)
inference_state["consolidated_frame_inds"] = {
"cond_frame_outputs": set(), # set containing frame indices
"non_cond_frame_outputs": set(), # set containing frame indices
}
# metadata for each tracking frame (e.g. which direction it's tracked)
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"] = {}
return inference_state
def add_language_embd(
self,
inference_state,
frame_idx,
obj_id,
language_embd,
inference=False,
):
obj_idx = _obj_id_to_idx(inference_state, obj_id)
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
# whether to track in reverse time order
if is_init_cond_frame:
reverse = False
else:
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or
# if the model sees all frames receiving clicks/mask as conditioning frames.
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Get any previously predicted mask logits on this object and feed it along with
# the new clicks into the SAM mask decoder.
prev_sam_mask_logits = None
# lookup temporary output dict first, which contains the most recent output
# (if not found, then lookup conditioning and non-conditioning frame output)
prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
if prev_out is None:
prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
if prev_out is None:
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
if prev_out is not None and prev_out["pred_masks"] is not None:
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, pred_mask_gpu = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=obj_output_dict, # run on the slice of a single object
frame_idx=frame_idx,
batch_size=1, # run on the slice of a single object
is_init_cond_frame=is_init_cond_frame,
point_inputs=None,
mask_inputs=None,
reverse=reverse,
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
# at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
# allows us to enforce non-overlapping constraints on all objects before encoding
# them into memory.
run_mem_encoder=False,
prev_sam_mask_logits=prev_sam_mask_logits,
## Extension: LLM prompt
language_embd=language_embd,
)
# Add the output to the output dict (to be used as future memory)
obj_temp_output_dict[storage_key][frame_idx] = current_out
# Resize the output mask to the original video resolution
obj_ids = inference_state["obj_ids"]
if inference:
_consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=False,
)
# _, video_res_masks = self._get_orig_video_res_output(
# inference_state, consolidated_out["pred_masks_video_res"]
# )
return frame_idx, obj_ids, pred_mask_gpu
def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
"""
Remove the non-conditioning memory around the input frame. When users provide
correction clicks, the surrounding frames' non-conditioning memories can still
contain outdated object appearance information and could confuse the model.
This method clears those non-conditioning memories surrounding the interacted
frame to avoid giving the model both old and new information about the object.
"""
r = self.memory_temporal_stride_for_eval
frame_idx_begin = frame_idx - r * self.num_maskmem
frame_idx_end = frame_idx + r * self.num_maskmem
output_dict = inference_state["output_dict"]
non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
for t in range(frame_idx_begin, frame_idx_end + 1):
non_cond_frame_outputs.pop(t, None)
for obj_output_dict in inference_state["output_dict_per_obj"].values():
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
def _run_memory_encoder(
self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
):
"""
Run the memory encoder on `high_res_masks`. This is usually after applying
non-overlapping constraints to object scores. Since their scores changed, their
memory also needs to be computed again with the memory encoder.
"""
# Retrieve correct image features
_, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
inference_state, frame_idx, batch_size
)
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_masks,
is_mask_from_pts=is_mask_from_pts,
)
# optionally offload the output to CPU memory to save GPU space
storage_device = inference_state["storage_device"]
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
maskmem_pos_enc = _get_maskmem_pos_enc(
inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
)
return maskmem_features, maskmem_pos_enc
def _add_output_per_object(
self, inference_state, frame_idx, current_out, storage_key
):
"""
Split a multi-object output into per-object output slices and add them into
`output_dict_per_obj`. The resulting slices share the same tensor storage.
"""
maskmem_features = current_out["maskmem_features"]
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
maskmem_pos_enc = current_out["maskmem_pos_enc"]
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
output_dict_per_obj = inference_state["output_dict_per_obj"]
for obj_idx, obj_output_dict in output_dict_per_obj.items():
obj_slice = slice(obj_idx, obj_idx + 1)
obj_out = {
"maskmem_features": None,
"maskmem_pos_enc": None,
"pred_masks": current_out["pred_masks"][obj_slice],
"obj_ptr": current_out["obj_ptr"][obj_slice],
}
if maskmem_features is not None:
obj_out["maskmem_features"] = maskmem_features[obj_slice]
if maskmem_pos_enc is not None:
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
obj_output_dict[storage_key][frame_idx] = obj_out
@torch.inference_mode()
def propagate_in_video_preflight(self, inference_state):
"""Prepare inference_state and consolidate temporary outputs before tracking."""
# Tracking has started and we don't allow adding new objects until the session is reset.
inference_state["tracking_has_started"] = True
batch_size = _get_obj_num(inference_state)
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
# add them into "output_dict".
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
output_dict = inference_state["output_dict"]
# "consolidated_frame_inds" contains indices of those frames where consolidated
# temporary outputs have been added (either in this call or any previous calls
# to `propagate_in_video_preflight`).
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
for is_cond in [False, True]:
# Separately consolidate conditioning and non-conditioning temp outputs
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Find all the frames that contain temporary outputs for any objects
# (these should be the frames that have just received clicks or mask inputs
# via `add_new_points` or `add_new_mask`)
temp_frame_inds = set()
for obj_temp_output_dict in temp_output_dict_per_obj.values():
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
consolidated_frame_inds[storage_key].update(temp_frame_inds)
# consolidate the temporary output across all objects on this frame
for frame_idx in temp_frame_inds:
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
)
# merge them into "output_dict" and also create per-object slices
output_dict[storage_key][frame_idx] = consolidated_out
self._add_output_per_object(
inference_state, frame_idx, consolidated_out, storage_key
)
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
if clear_non_cond_mem:
# clear non-conditioning memory of the surrounding frames
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
# clear temporary outputs in `temp_output_dict_per_obj`
for obj_temp_output_dict in temp_output_dict_per_obj.values():
obj_temp_output_dict[storage_key].clear()
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
# output on the same frame in "non_cond_frame_outputs"
for frame_idx in output_dict["cond_frame_outputs"]:
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
for obj_output_dict in inference_state["output_dict_per_obj"].values():
for frame_idx in obj_output_dict["cond_frame_outputs"]:
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
assert frame_idx in output_dict["cond_frame_outputs"]
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
# with either points or mask inputs (which should be true under a correct workflow).
all_consolidated_frame_inds = (
consolidated_frame_inds["cond_frame_outputs"]
| consolidated_frame_inds["non_cond_frame_outputs"]
)
input_frames_inds = set()
for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
input_frames_inds.update(point_inputs_per_frame.keys())
for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
input_frames_inds.update(mask_inputs_per_frame.keys())
# with language embeddings as input, there may be no point or box prompts
# assert all_consolidated_frame_inds == input_frames_inds
@torch.inference_mode()
def propagate_in_video(
self,
inference_state,
start_frame_idx=None,
max_frame_num_to_track=None,
reverse=False,
):
"""Propagate the input points across frames to track in the entire video."""
self.propagate_in_video_preflight(inference_state)
output_dict = inference_state["output_dict"]
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
obj_ids = inference_state["obj_ids"]
num_frames = inference_state["num_frames"]
batch_size = _get_obj_num(inference_state)
if len(output_dict["cond_frame_outputs"]) == 0:
raise RuntimeError("No points are provided; please add points first")
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
# set start index, end index, and processing order
if start_frame_idx is None:
# default: start from the earliest frame with input points
start_frame_idx = min(output_dict["cond_frame_outputs"])
if max_frame_num_to_track is None:
# default: track all the frames in the video
max_frame_num_to_track = num_frames
if reverse:
end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
if start_frame_idx > 0:
processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
else:
processing_order = [] # skip reverse tracking if starting from frame 0
else:
end_frame_idx = min(
start_frame_idx + max_frame_num_to_track, num_frames - 1
)
processing_order = range(start_frame_idx, end_frame_idx + 1)
for frame_idx in tqdm(processing_order, desc="propagate in video"):
# We skip those frames already in consolidated outputs (these are frames
# that received input clicks or mask). Note that we cannot directly run
# batched forward on them via `_run_single_frame_inference` because the
# number of clicks on each object might be different.
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
storage_key = "cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
if clear_non_cond_mem:
# clear non-conditioning memory of the surrounding frames
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
storage_key = "non_cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
else:
storage_key = "non_cond_frame_outputs"
current_out, pred_masks = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=output_dict,
frame_idx=frame_idx,
batch_size=batch_size,
is_init_cond_frame=False,
point_inputs=None,
mask_inputs=None,
reverse=reverse,
run_mem_encoder=True,
)
output_dict[storage_key][frame_idx] = current_out
# Create slices of per-object outputs for subsequent interaction with each
# individual object after tracking.
self._add_output_per_object(
inference_state, frame_idx, current_out, storage_key
)
inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
# Resize the output mask to the original video resolution (we directly use
# the mask scores on GPU for output to avoid any CPU conversion in between)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, pred_masks
)
yield frame_idx, obj_ids, video_res_masks
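# --- Hedged usage sketch (not part of the original file) ---------------------------------
# A minimal example of how this predictor's language-prompted workflow could be driven:
# build an inference state, register one language embedding per object on the first frame
# via `add_language_embd`, then propagate across the video. The names `predictor`, `frames`
# (a (T, 3, H, W) float tensor) and `embeds` (a (T, N, C) tensor) are assumptions made for
# illustration only.
def _example_language_prompted_tracking(predictor, frames, embeds):
    state = predictor.init_state(frames)
    num_objs = embeds.shape[1]
    for obj_idx in range(num_objs):
        # each prompt is passed with shape (1, 1, C), one object on the first frame
        predictor.add_language_embd(state, 0, obj_idx, embeds[0, obj_idx][None, None])
    results = {}
    for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
        results[frame_idx] = (obj_ids, masks)
    return results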
import numpy as np
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
class DirectResize:
def __init__(self, target_length: int) -> None:
self.target_length = target_length
def apply_image(self, image: np.ndarray) -> np.ndarray:
"""
Expects a numpy array with shape HxWxC in uint8 format.
"""
img = to_pil_image(image, mode='RGB')
return np.array(img.resize((self.target_length, self.target_length)))
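# Hedged usage sketch (not part of the original file): `frame` below is a hypothetical
# HxWx3 uint8 numpy array; DirectResize squashes it to a square of side `target_length`.
def _example_direct_resize():
    frame = np.zeros((480, 640, 3), dtype=np.uint8)
    resizer = DirectResize(target_length=1024)
    resized = resizer.apply_image(frame)
    assert resized.shape == (1024, 1024, 3)
    return resized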
import os.path
import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from mmengine.model import BaseModule
from vlm.utils import load_checkpoint_with_prefix, load_state_dict_to_model
BASE_DIR = 'work_dirs/ckpt'
class SAM2(BaseModule):
def __init__(
self,
cfg_path: str = "sam2_hiera_l.yaml",
ckpt_path: str = "sam2_hiera_large.pt",
hydra_overrides_extra=None,
apply_postprocessing=True,
):
super().__init__(init_cfg=None)
import third_parts.sam2 # noqa: F401
if hydra_overrides_extra is None:
hydra_overrides_extra = []
hydra_overrides = [
## Extension: LLM prompt
"++model._target_=projects.llava_sam2.models.predictor.SAM2VideoPredictor",
]
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
# binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder, so that the encoded masks match exactly what users see from clicking
# "++model.binarize_mask_from_pts_for_mem_enc=true",
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
# "++model.fill_hole_area=8",
]
hydra_overrides.extend(hydra_overrides_extra)
# Read config and init model
cfg = compose(config_name=cfg_path, overrides=hydra_overrides)
OmegaConf.resolve(cfg)
sam2_model = instantiate(cfg.model, _recursive_=True)
state_dict = load_checkpoint_with_prefix(os.path.join(BASE_DIR, ckpt_path))
load_state_dict_to_model(sam2_model, state_dict)
self.sam2_model = sam2_model
self.hidden_dim = self.sam2_model.hidden_dim
self.img_mean = (0.485, 0.456, 0.406)
self.img_std = (0.229, 0.224, 0.225)
def inject_language_embd(self, inference_state, language_embd):
num_frame = len(language_embd)
num_obj = len(language_embd[0])
mask_out = []
for frame_idx in range(num_frame):
frame_mask_out = []
for obj_idx in range(num_obj):
_language_embd = language_embd[frame_idx][obj_idx][None][None]
_, _, out_mask_logits = self.sam2_model.add_language_embd(inference_state, frame_idx, obj_idx + 100, _language_embd)
frame_mask_out.append(out_mask_logits)
frame_mask_out = torch.cat(frame_mask_out, dim=1)
mask_out.append(frame_mask_out)
mask_out = torch.cat(mask_out, dim=0)
return mask_out
def language_embd_inference(self, inference_state, language_embd):
num_frame = len(language_embd)
num_obj = len(language_embd[0])
mask_out = []
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
for frame_idx in range(num_frame):
frame_mask_out = []
for obj_idx in range(num_obj):
_language_embd = language_embd[frame_idx][obj_idx][None][None]
_, _, out_mask_logits = self.sam2_model.add_language_embd(
inference_state,
frame_idx,
obj_idx + 100,
_language_embd,
inference=True,
)
frame_mask_out.append(out_mask_logits)
frame_mask_out = torch.cat(frame_mask_out, dim=1)
mask_out.append(frame_mask_out)
mask_out = []
for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_model.propagate_in_video(inference_state):
mask_out.append(out_mask_logits)
mask_out = torch.cat(mask_out, dim=0)
return mask_out
def get_sam2_embeddings(self, images):
return self.sam2_model.init_state(images)
def forward(self, batch):
raise NotImplementedError
def preprocess_image(self, image: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
image = image / 255.
img_mean = torch.tensor(self.img_mean, dtype=dtype, device=image.device)[:, None, None]
img_std = torch.tensor(self.img_std, dtype=dtype, device=image.device)[:, None, None]
image -= img_mean
image /= img_std
return image
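# Hedged end-to-end sketch (not part of the original file) of how this wrapper is typically
# driven: frames are normalized, an inference state is built via `get_sam2_embeddings`
# (which wraps `init_state`), and per-frame language embeddings are injected before
# propagating masks through the video. `frames` ((T, 3, H, W), values in [0, 255]) and
# `lang_embeds` ((T, N, C)) are assumed inputs for illustration only.
def _example_sam2_language_inference(sam2: SAM2, frames: torch.Tensor, lang_embeds: torch.Tensor):
    frames = torch.stack([sam2.preprocess_image(f.float()) for f in frames])
    state = sam2.get_sam2_embeddings(frames)
    return sam2.language_embd_inference(state, lang_embeds)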
import os.path
import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from mmengine.model import BaseModule
from vlm.utils import load_checkpoint_with_prefix, load_state_dict_to_model
BASE_DIR = 'pretrained/'
class SAM2TrainRunner(BaseModule):
def __init__(
self,
cfg_path: str = "sam2_hiera_l.yaml",
ckpt_path: str = "sam2_hiera_large.pt",
hydra_overrides_extra=None,
apply_postprocessing=True,
):
super().__init__(init_cfg=None)
import third_parts.sam2 # noqa: F401
if hydra_overrides_extra is None:
hydra_overrides_extra = []
hydra_overrides = [
## Extension: LLM prompt
"++model._target_=projects.llava_sam2.models.extension.SAM2Base",
]
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
# "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
# "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
# "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
# binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder, so that the encoded masks match exactly what users see from clicking
# "++model.binarize_mask_from_pts_for_mem_enc=true",
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
# "++model.fill_hole_area=8",
]
hydra_overrides.extend(hydra_overrides_extra)
# Read config and init model
cfg = compose(config_name=cfg_path, overrides=hydra_overrides)
OmegaConf.resolve(cfg)
sam2_model = instantiate(cfg.model, _recursive_=True)
state_dict = load_checkpoint_with_prefix(os.path.join(BASE_DIR, ckpt_path))
load_state_dict_to_model(sam2_model, state_dict)
self.sam2_model = sam2_model
self.hidden_dim = self.sam2_model.hidden_dim
self.img_mean = (0.485, 0.456, 0.406)
self.img_std = (0.229, 0.224, 0.225)
def preprocess_image(self, image: torch.Tensor) -> torch.Tensor:
image = image / 255.
img_mean = torch.tensor(self.img_mean, dtype=image.dtype, device=image.device)[:, None, None]
img_std = torch.tensor(self.img_std, dtype=image.dtype, device=image.device)[:, None, None]
image -= img_mean
image /= img_std
return image
def inject_language_embd(self, sam_states, language_embd, nf_nobj=None):
high_res_features = [
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
for x, s in zip(sam_states['current_vision_feats'][:-1], sam_states['feat_sizes'][:-1])
]
B = sam_states['current_vision_feats'][-1].size(1) # batch size on this frame
C = self.hidden_dim
H, W = sam_states['feat_sizes'][-1]
if self.sam2_model.directly_add_no_mem_embed:
# directly add no-mem embedding (instead of using the transformer encoder)
pix_feat_with_mem = sam_states['current_vision_feats'][-1] + self.sam2_model.no_mem_embed
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
else:
raise NotImplementedError("directly add no memory embedding is not implemented")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
_, _, _, low_res_masks, high_res_masks, obj_ptr, _, = self.sam2_model._forward_sam_heads(
backbone_features=pix_feat_with_mem,
point_inputs=None,
mask_inputs=None,
high_res_features=high_res_features,
multimask_output=self.sam2_model._use_multimask(is_init_cond_frame=True, point_inputs=None),
# Inject language Embed if possible
language_embd=language_embd,
)
if nf_nobj is not None:
pred_masks = low_res_masks.squeeze(1)
pred_masks = pred_masks.unflatten(0, nf_nobj)
else:
pred_masks = low_res_masks
return pred_masks
def get_sam2_embeddings(self, images, expand_size=1):
# Step 1: run the image backbone on the input images
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
feats = self.sam2_model.forward_image(images)
if expand_size > 1:
# feats['vision_features'] = feats['vision_features'][:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
for i, feat in enumerate(feats["backbone_fpn"]):
feats["backbone_fpn"][i] = feat[:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
for i, pos in enumerate(feats["vision_pos_enc"]):
pos = pos[:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
feats["vision_pos_enc"][i] = pos
# Step 2: Process the features to output
_, current_vision_feats, current_vision_pos_embeds, feat_sizes = self.sam2_model._prepare_backbone_features(feats)
return {
"current_vision_feats": current_vision_feats,
"current_vision_pos_embeds": current_vision_pos_embeds,
"feat_sizes": feat_sizes,
}
def forward(self, batch):
raise NotImplementedError
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image,
min_num=1,
max_num=6,
image_size=448,
use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = {(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1) for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
target_ratios, orig_width,
orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = ((i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
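# Hedged usage sketch (not part of the original file, requires Pillow): tile a dummy
# 1000x750 image into 448x448 crops of an aspect-ratio-matched resize, appending a global
# thumbnail at the end when `use_thumbnail=True`.
def _example_dynamic_preprocess():
    from PIL import Image
    img = Image.new('RGB', (1000, 750))
    tiles = dynamic_preprocess(img, max_num=6, image_size=448, use_thumbnail=True)
    assert all(tile.size == (448, 448) for tile in tiles)
    return tiles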
transformers==4.42.3
xtuner[deepspeed]==0.1.23
timm==1.0.9
mmdet==3.3.0
hydra-core==1.3.2
ninja==1.11.1
decord==0.6.0
peft==0.11.1
gradio==4.44.0
from .video_io import VideoReader
# Copyright (c) OpenMMLab. All rights reserved.
import collections
import os.path as osp
import random
from typing import Dict, List
import mmengine
from mmengine.dataset import BaseDataset
# from mmdet.registry import DATASETS
# @DATASETS.register_module()
class RefCocoDataset(BaseDataset):
"""RefCOCO dataset.
The `Refcoco` and `Refcoco+` datasets are based on
`ReferItGame: Referring to Objects in Photographs of Natural Scenes
<http://tamaraberg.com/papers/referit.pdf>`_.
The `Refcocog` dataset is based on
`Generation and Comprehension of Unambiguous Object Descriptions
<https://arxiv.org/abs/1511.02283>`_.
Args:
ann_file (str): Annotation file path.
data_root (str): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
data_prefix (str): Prefix for training data.
split_file (str): Split file path.
split (str): Split name. Defaults to 'train'.
text_mode (str): Text mode. Defaults to 'random'.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
def __init__(self,
data_root: str,
ann_file: str,
split_file: str,
data_prefix: Dict,
split: str = 'train',
text_mode: str = 'random',
**kwargs):
self.split_file = split_file
self.split = split
assert text_mode in ['original', 'random', 'concat', 'select_first']
self.text_mode = text_mode
super().__init__(
data_root=data_root,
data_prefix=data_prefix,
ann_file=ann_file,
**kwargs,
)
def _join_prefix(self):
if not mmengine.is_abs(self.split_file) and self.split_file:
self.split_file = osp.join(self.data_root, self.split_file)
return super()._join_prefix()
def _init_refs(self):
"""Initialize the refs for RefCOCO."""
anns, imgs = {}, {}
for ann in self.instances['annotations']:
anns[ann['id']] = ann
for img in self.instances['images']:
imgs[img['id']] = img
refs, ref_to_ann = {}, {}
for ref in self.splits:
# ids
ref_id = ref['ref_id']
ann_id = ref['ann_id']
# add mapping related to ref
refs[ref_id] = ref
ref_to_ann[ref_id] = anns[ann_id]
self.refs = refs
self.ref_to_ann = ref_to_ann
def load_data_list(self) -> List[dict]:
"""Load data list."""
self.splits = mmengine.load(self.split_file, file_format='pkl')
self.instances = mmengine.load(self.ann_file, file_format='json')
self._init_refs()
img_prefix = self.data_prefix['img_path']
ref_ids = [
ref['ref_id'] for ref in self.splits if ref['split'] == self.split
]
full_anno = []
for ref_id in ref_ids:
ref = self.refs[ref_id]
ann = self.ref_to_ann[ref_id]
ann.update(ref)
full_anno.append(ann)
image_id_list = []
final_anno = {}
for anno in full_anno:
image_id_list.append(anno['image_id'])
final_anno[anno['ann_id']] = anno
annotations = [value for key, value in final_anno.items()]
coco_train_id = []
image_annot = {}
for i in range(len(self.instances['images'])):
coco_train_id.append(self.instances['images'][i]['id'])
image_annot[self.instances['images'][i]
['id']] = self.instances['images'][i]
images = []
for image_id in list(set(image_id_list)):
images += [image_annot[image_id]]
data_list = []
grounding_dict = collections.defaultdict(list)
for anno in annotations:
image_id = int(anno['image_id'])
grounding_dict[image_id].append(anno)
join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
for image in images:
img_id = image['id']
instances = []
sentences = []
for grounding_anno in grounding_dict[img_id]:
texts = [x['raw'].lower() for x in grounding_anno['sentences']]
# random select one text
if self.text_mode == 'random':
idx = random.randint(0, len(texts) - 1)
text = [texts[idx]]
# concat all texts
elif self.text_mode == 'concat':
text = [''.join(texts)]
# select the first text
elif self.text_mode == 'select_first':
text = [texts[0]]
# use all texts
elif self.text_mode == 'original':
text = texts
else:
raise ValueError(f'Invalid text mode "{self.text_mode}".')
ins = [{
'mask': grounding_anno['segmentation'],
'ignore_flag': 0
}] * len(text)
instances.extend(ins)
sentences.extend(text)
data_info = {
'img_path': join_path(img_prefix, image['file_name']),
'img_id': img_id,
'instances': instances,
'text': sentences
}
data_list.append(data_info)
if len(data_list) == 0:
raise ValueError(f'No sample in split "{self.split}".')
return data_list
from .cross_entropy_loss import CrossEntropyLoss
from .dice_loss import DiceLoss
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
def accuracy(pred, target, topk=1, thresh=None):
"""Calculate accuracy according to the prediction and target.
Args:
pred (torch.Tensor): The model prediction, shape (N, num_class)
target (torch.Tensor): The target of each prediction, shape (N, )
topk (int | tuple[int], optional): If the predictions in ``topk``
match the target, the predictions will be regarded as
correct ones. Defaults to 1.
thresh (float, optional): If not None, predictions with scores under
this threshold are considered incorrect. Defaults to None.
Returns:
float | tuple[float]: If the input ``topk`` is a single integer,
the function will return a single float as accuracy. If
``topk`` is a tuple containing multiple integers, the
function will return a tuple containing accuracies of
each ``topk`` number.
"""
assert isinstance(topk, (int, tuple))
if isinstance(topk, int):
topk = (topk, )
return_single = True
else:
return_single = False
maxk = max(topk)
if pred.size(0) == 0:
accu = [pred.new_tensor(0.) for i in range(len(topk))]
return accu[0] if return_single else accu
assert pred.ndim == 2 and target.ndim == 1
assert pred.size(0) == target.size(0)
assert maxk <= pred.size(1), \
f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
pred_value, pred_label = pred.topk(maxk, dim=1)
pred_label = pred_label.t() # transpose to shape (maxk, N)
correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
if thresh is not None:
# Only prediction values larger than thresh are counted as correct
correct = correct & (pred_value > thresh).t()
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / pred.size(0)))
return res[0] if return_single else res
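# Hedged worked example (not part of the original file): with three samples, the top-1
# prediction matches the target on two of them, so accuracy() returns roughly 66.7 percent.
def _example_accuracy():
    import torch
    pred = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
    target = torch.tensor([1, 0, 0])
    return accuracy(pred, target, topk=1)  # tensor([66.6667])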
class Accuracy(nn.Module):
def __init__(self, topk=(1, ), thresh=None):
"""Module to calculate the accuracy.
Args:
topk (tuple, optional): The criterion used to calculate the
accuracy. Defaults to (1,).
thresh (float, optional): If not None, predictions with scores
under this threshold are considered incorrect. Defaults to None.
"""
super().__init__()
self.topk = topk
self.thresh = thresh
def forward(self, pred, target):
"""Forward function to calculate accuracy.
Args:
pred (torch.Tensor): Prediction of models.
target (torch.Tensor): Target for each prediction.
Returns:
tuple[float]: The accuracies under different topk criterions.
"""
return accuracy(pred, target, self.topk, self.thresh)
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
# from mmdet.registry import MODELS
from .accuracy import accuracy
from .utils import weight_reduce_loss
def cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100,
avg_non_ignore=False):
"""Calculate the CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
avg_non_ignore (bool): The flag decides whether the loss is
only averaged over non-ignored targets. Default: False.
Returns:
torch.Tensor: The calculated loss
"""
# The default value of ignore_index is the same as F.cross_entropy
ignore_index = -100 if ignore_index is None else ignore_index
# element-wise losses
loss = F.cross_entropy(
pred,
label,
weight=class_weight,
reduction='none',
ignore_index=ignore_index)
# average loss over non-ignored elements
# pytorch's official cross_entropy average loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = label.numel() - (label == ignore_index).sum().item()
# apply weights and do the reduction
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
return loss
def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
"""Expand onehot labels to match the size of prediction."""
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(
valid_mask & (labels < label_channels), as_tuple=False)
if inds.numel() > 0:
bin_labels[inds, labels[inds]] = 1
valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
label_channels).float()
if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
bin_label_weights *= valid_mask
return bin_labels, bin_label_weights, valid_mask
def binary_cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100,
avg_non_ignore=False):
"""Calculate the binary CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
When the shape of pred is (N, 1), label will be expanded to
one-hot format, and when the shape of pred is (N, ), label
will not be expanded to one-hot format.
label (torch.Tensor): The learning label of the prediction,
with shape (N, ).
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
avg_non_ignore (bool): The flag decides whether the loss is
only averaged over non-ignored targets. Default: False.
Returns:
torch.Tensor: The calculated loss.
"""
# The default value of ignore_index is the same as F.cross_entropy
ignore_index = -100 if ignore_index is None else ignore_index
if pred.dim() != label.dim():
label, weight, valid_mask = _expand_onehot_labels(
label, weight, pred.size(-1), ignore_index)
else:
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
# The inplace writing method will have a mismatched broadcast
# shape error if the weight and valid_mask dimensions
# are inconsistent such as (B,N,1) and (B,N,C).
weight = weight * valid_mask
else:
weight = valid_mask
# average loss over non-ignored elements
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = valid_mask.sum().item()
# weighted element-wise losses
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
loss = weight_reduce_loss(
loss, weight, reduction=reduction, avg_factor=avg_factor)
return loss
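# Hedged worked example (not part of the original file): two logits scored against binary
# targets with `binary_cross_entropy`, averaged over both elements.
def _example_binary_cross_entropy():
    pred = torch.tensor([2.0, -1.0])
    label = torch.tensor([1, 0])
    return binary_cross_entropy(pred, label, reduction='mean')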
def mask_cross_entropy(pred,
target,
label,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=None,
**kwargs):
"""Calculate the CrossEntropy loss for masks.
Args:
pred (torch.Tensor): The prediction with shape (N, C, *), C is the
number of classes. The trailing * indicates arbitrary shape.
target (torch.Tensor): The learning label of the prediction.
label (torch.Tensor): ``label`` indicates the class label of the object
corresponding to the mask. It is used to select the mask of the class
that the object belongs to when the mask prediction is not
class-agnostic.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (None): Placeholder, to be consistent with other loss.
Default: None.
Returns:
torch.Tensor: The calculated loss
Example:
>>> N, C = 3, 11
>>> H, W = 2, 2
>>> pred = torch.randn(N, C, H, W) * 1000
>>> target = torch.rand(N, H, W)
>>> label = torch.randint(0, C, size=(N,))
>>> reduction = 'mean'
>>> avg_factor = None
>>> class_weights = None
>>> loss = mask_cross_entropy(pred, target, label, reduction,
>>> avg_factor, class_weights)
>>> assert loss.shape == (1,)
"""
assert ignore_index is None, 'BCE loss does not support ignore_index'
# TODO: handle these two reserved arguments
assert reduction == 'mean' and avg_factor is None
num_rois = pred.size()[0]
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, weight=class_weight, reduction='mean')[None]
# @MODELS.register_module()
class CrossEntropyLoss(nn.Module):
def __init__(self,
use_sigmoid=False,
use_mask=False,
reduction='mean',
class_weight=None,
ignore_index=None,
loss_weight=1.0,
avg_non_ignore=False):
"""CrossEntropyLoss.
Args:
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
or softmax. Defaults to False.
use_mask (bool, optional): Whether to use mask cross entropy loss.
Defaults to False.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum". Defaults to 'mean'.
class_weight (list[float], optional): Weight of each class.
Defaults to None.
ignore_index (int | None): The label index to be ignored.
Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
avg_non_ignore (bool): The flag decides whether the loss is
only averaged over non-ignored targets. Default: False.
"""
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
self.ignore_index = ignore_index
self.avg_non_ignore = avg_non_ignore
if ((ignore_index is not None) and not self.avg_non_ignore
and self.reduction == 'mean'):
warnings.warn(
'Default ``avg_non_ignore`` is False; if you would like to '
'ignore a certain label and average the loss over non-ignored '
'labels (the same as the PyTorch official cross_entropy), '
'set ``avg_non_ignore=True``.')
if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
elif self.use_mask:
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = cross_entropy
def extra_repr(self):
"""Extra repr."""
s = f'avg_non_ignore={self.avg_non_ignore}'
return s
def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
ignore_index=None,
**kwargs):
"""Forward function.
Args:
cls_score (torch.Tensor): The prediction.
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The method used to reduce the
loss. Options are "none", "mean" and "sum".
ignore_index (int | None): The label index to be ignored.
If not None, it will override the default value. Default: None.
Returns:
torch.Tensor: The calculated loss.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if ignore_index is None:
ignore_index = self.ignore_index
if self.class_weight is not None:
class_weight = cls_score.new_tensor(
self.class_weight, device=cls_score.device)
else:
class_weight = None
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
ignore_index=ignore_index,
avg_non_ignore=self.avg_non_ignore,
**kwargs)
return loss_cls
# @MODELS.register_module()
class CrossEntropyCustomLoss(CrossEntropyLoss):
def __init__(self,
use_sigmoid=False,
use_mask=False,
reduction='mean',
num_classes=-1,
class_weight=None,
ignore_index=None,
loss_weight=1.0,
avg_non_ignore=False):
"""CrossEntropyCustomLoss.
Args:
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
or softmax. Defaults to False.
use_mask (bool, optional): Whether to use mask cross entropy loss.
Defaults to False.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum". Defaults to 'mean'.
num_classes (int): Number of classes to classify.
class_weight (list[float], optional): Weight of each class.
Defaults to None.
ignore_index (int | None): The label index to be ignored.
Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
avg_non_ignore (bool): The flag decides whether the loss is
only averaged over non-ignored targets. Default: False.
"""
super(CrossEntropyCustomLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
self.ignore_index = ignore_index
self.avg_non_ignore = avg_non_ignore
if ((ignore_index is not None) and not self.avg_non_ignore
and self.reduction == 'mean'):
            warnings.warn(
                'Default ``avg_non_ignore`` is False. If you would like to '
                'ignore a certain label and average the loss only over '
                'non-ignored labels, matching the official PyTorch '
                'cross_entropy, set ``avg_non_ignore=True``.')
if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
elif self.use_mask:
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = cross_entropy
self.num_classes = num_classes
assert self.num_classes != -1
# custom output channels of the classifier
self.custom_cls_channels = True
# custom activation of cls_score
self.custom_activation = True
        # custom accuracy of the classifier
self.custom_accuracy = True
def get_cls_channels(self, num_classes):
assert num_classes == self.num_classes
if not self.use_sigmoid:
return num_classes + 1
else:
return num_classes
def get_activation(self, cls_score):
fine_cls_score = cls_score[:, :self.num_classes]
if not self.use_sigmoid:
bg_score = cls_score[:, [-1]]
new_score = torch.cat([fine_cls_score, bg_score], dim=-1)
scores = F.softmax(new_score, dim=-1)
else:
score_classes = fine_cls_score.sigmoid()
score_neg = 1 - score_classes.sum(dim=1, keepdim=True)
score_neg = score_neg.clamp(min=0, max=1)
scores = torch.cat([score_classes, score_neg], dim=1)
return scores
def get_accuracy(self, cls_score, labels):
fine_cls_score = cls_score[:, :self.num_classes]
pos_inds = labels < self.num_classes
acc_classes = accuracy(fine_cls_score[pos_inds], labels[pos_inds])
acc = dict()
acc['acc_classes'] = acc_classes
return acc
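# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# Shows the custom-channel convention of CrossEntropyCustomLoss: in the softmax
# path the classifier outputs `num_classes + 1` channels (the last one is
# background) and `get_activation` turns raw scores into per-class probabilities.
# `_demo_custom_ce_activation` is a hypothetical helper name.
def _demo_custom_ce_activation():
    import torch
    num_classes = 3
    criterion = CrossEntropyCustomLoss(use_sigmoid=False, num_classes=num_classes)
    n_channels = criterion.get_cls_channels(num_classes)   # num_classes + 1
    cls_score = torch.randn(2, n_channels)
    scores = criterion.get_activation(cls_score)            # softmax over [classes, bg]
    assert scores.shape == (2, num_classes + 1)
    return scores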
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
# from mmdet.registry import MODELS
from .utils import weight_reduce_loss
def dice_loss(pred,
target,
weight=None,
eps=1e-3,
reduction='mean',
naive_dice=False,
avg_factor=None):
"""Calculate dice loss, there are two forms of dice loss is supported:
- the one proposed in `V-Net: Fully Convolutional Neural
Networks for Volumetric Medical Image Segmentation
<https://arxiv.org/abs/1606.04797>`_.
- the dice loss in which the power of the number in the
denominator is the first power instead of the second
power.
Args:
pred (torch.Tensor): The prediction, has a shape (n, *)
target (torch.Tensor): The learning label of the prediction,
shape (n, *), same shape of pred.
weight (torch.Tensor, optional): The weight of loss for each
prediction, has a shape (n,). Defaults to None.
eps (float): Avoid dividing by zero. Default: 1e-3.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
Options are "none", "mean" and "sum".
        naive_dice (bool, optional): If False, use the dice loss defined
            in the V-Net paper; otherwise, use the naive dice loss in which
            the terms in the denominator are raised to the first power
            instead of the second power. Defaults to False.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
input = pred.flatten(1)
target = target.flatten(1).float()
a = torch.sum(input * target, 1)
if naive_dice:
b = torch.sum(input, 1)
c = torch.sum(target, 1)
d = (2 * a + eps) / (b + c + eps)
else:
b = torch.sum(input * input, 1) + eps
c = torch.sum(target * target, 1) + eps
d = (2 * a) / (b + c)
loss = 1 - d
if weight is not None:
assert weight.ndim == loss.ndim
assert len(weight) == len(pred)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
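# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# A tiny worked example of `dice_loss` on two flattened masks: a perfect
# prediction (loss near 0) and an empty prediction (loss near 1).
# `_demo_dice_loss` is a hypothetical helper name.
def _demo_dice_loss():
    import torch
    pred = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                         [0.0, 0.0, 0.0, 0.0]])
    target = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                           [1.0, 1.0, 0.0, 0.0]])
    # reduction='none' keeps the per-sample losses: roughly [0.0, 1.0]
    return dice_loss(pred, target, reduction='none', naive_dice=True)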
# @MODELS.register_module()
class DiceLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=False,
loss_weight=1.0,
eps=1e-3):
"""Compute dice loss.
Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                or softmax. Defaults to True.
            activate (bool): Whether to apply the activation (sigmoid) to the
                predictions inside the loss; set this to False if the
                predictions are already activated. Defaults to True.
reduction (str, optional): The method used
to reduce the loss. Options are "none",
"mean" and "sum". Defaults to 'mean'.
naive_dice (bool, optional): If false, use the dice
loss defined in the V-Net paper, otherwise, use the
naive dice loss in which the power of the number in the
denominator is the first power instead of the second
power. Defaults to False.
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
eps (float): Avoid dividing by zero. Defaults to 1e-3.
"""
super(DiceLoss, self).__init__()
self.use_sigmoid = use_sigmoid
self.reduction = reduction
self.naive_dice = naive_dice
self.loss_weight = loss_weight
self.eps = eps
self.activate = activate
def forward(self,
pred,
target,
weight=None,
reduction_override=None,
avg_factor=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction, has a shape (n, *).
target (torch.Tensor): The label of the prediction,
shape (n, *), same shape of pred.
weight (torch.Tensor, optional): The weight of loss for each
prediction, has a shape (n,). Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.activate:
if self.use_sigmoid:
pred = pred.sigmoid()
else:
raise NotImplementedError
loss = self.loss_weight * dice_loss(
pred,
target,
weight,
eps=self.eps,
reduction=reduction,
naive_dice=self.naive_dice,
avg_factor=avg_factor)
return loss
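# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# Typical use of the DiceLoss module on raw mask logits; sigmoid is applied
# inside because `activate=True` and `use_sigmoid=True`.
# `_demo_dice_loss_module` is a hypothetical helper name.
def _demo_dice_loss_module():
    import torch
    pred_logits = torch.randn(2, 8, 8)                # raw logits per mask
    target = (torch.rand(2, 8, 8) > 0.5).float()      # binary targets
    criterion = DiceLoss(use_sigmoid=True, activate=True, naive_dice=True)
    return criterion(pred_logits, target)             # scalar ('mean' reduction)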
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Callable, Optional
import torch
import torch.nn.functional as F
from torch import Tensor
def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
"""Reduce loss as specified.
Args:
loss (Tensor): Elementwise loss tensor.
reduction (str): Options are "none", "mean" and "sum".
Return:
Tensor: Reduced loss tensor.
"""
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
elif reduction_enum == 2:
return loss.sum()
def weight_reduce_loss(loss: Tensor,
weight: Optional[Tensor] = None,
reduction: str = 'mean',
avg_factor: Optional[float] = None) -> Tensor:
"""Apply element-wise weight and reduce loss.
Args:
loss (Tensor): Element-wise loss.
weight (Optional[Tensor], optional): Element-wise weights.
Defaults to None.
reduction (str, optional): Same as built-in losses of PyTorch.
Defaults to 'mean'.
avg_factor (Optional[float], optional): Average factor when
computing the mean of losses. Defaults to None.
Returns:
Tensor: Processed loss values.
"""
# if weight is specified, apply element-wise weight
if weight is not None:
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
loss = reduce_loss(loss, reduction)
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = torch.finfo(torch.float32).eps
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
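# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# Demonstrates the difference between a plain 'mean' reduction and averaging by
# an explicit `avg_factor` (e.g. the number of non-ignored elements).
# `_demo_weight_reduce_loss` is a hypothetical helper name.
def _demo_weight_reduce_loss():
    loss = torch.tensor([1.0, 2.0, 3.0, 4.0])
    weight = torch.tensor([1.0, 0.0, 1.0, 0.0])   # mask out two elements
    mean_all = weight_reduce_loss(loss, weight, reduction='mean')                  # (1+0+3+0)/4 = 1.0
    mean_kept = weight_reduce_loss(loss, weight, reduction='mean', avg_factor=2)   # ~(1+3)/2 = 2.0
    return mean_all, mean_kept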
def weighted_loss(loss_func: Callable) -> Callable:
"""Create a weighted version of a given loss function.
To use this decorator, the loss function must have the signature like
`loss_func(pred, target, **kwargs)`. The function only needs to compute
element-wise loss without any reduction. This decorator will add weight
and reduction arguments to the function. The decorated function will have
the signature like `loss_func(pred, target, weight=None, reduction='mean',
avg_factor=None, **kwargs)`.
:Example:
>>> import torch
>>> @weighted_loss
>>> def l1_loss(pred, target):
>>> return (pred - target).abs()
>>> pred = torch.Tensor([0, 2, 3])
>>> target = torch.Tensor([1, 1, 1])
>>> weight = torch.Tensor([1, 0, 1])
>>> l1_loss(pred, target)
tensor(1.3333)
>>> l1_loss(pred, target, weight)
tensor(1.)
>>> l1_loss(pred, target, reduction='none')
tensor([1., 1., 2.])
>>> l1_loss(pred, target, weight, avg_factor=2)
tensor(1.5000)
"""
@functools.wraps(loss_func)
def wrapper(pred: Tensor,
target: Tensor,
weight: Optional[Tensor] = None,
reduction: str = 'mean',
avg_factor: Optional[int] = None,
**kwargs) -> Tensor:
"""
Args:
pred (Tensor): The prediction.
target (Tensor): Target bboxes.
weight (Optional[Tensor], optional): The weight of loss for each
prediction. Defaults to None.
reduction (str, optional): Options are "none", "mean" and "sum".
Defaults to 'mean'.
avg_factor (Optional[int], optional): Average factor that is used
to average the loss. Defaults to None.
Returns:
Tensor: Loss tensor.
"""
# get element-wise loss
loss = loss_func(pred, target, **kwargs)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
return wrapper
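# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# Decorating a plain element-wise loss gives it weight / reduction / avg_factor
# handling for free, mirroring the docstring example above. The decorated
# function name `_demo_abs_loss` is hypothetical.
@weighted_loss
def _demo_abs_loss(pred: Tensor, target: Tensor) -> Tensor:
    return (pred - target).abs()

def _demo_weighted_loss():
    pred = torch.tensor([0.0, 2.0, 3.0])
    target = torch.tensor([1.0, 1.0, 1.0])
    weight = torch.tensor([1.0, 0.0, 1.0])
    return _demo_abs_loss(pred, target, weight, avg_factor=2)  # ~1.5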
from .point_sample import get_uncertain_point_coords_with_randomness
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import point_sample
from torch import Tensor
def get_uncertainty(mask_preds: Tensor, labels: Tensor) -> Tensor:
"""Estimate uncertainty based on pred logits.
We estimate uncertainty as L1 distance between 0.0 and the logits
prediction in 'mask_preds' for the foreground class in `classes`.
Args:
mask_preds (Tensor): mask predication logits, shape (num_rois,
num_classes, mask_height, mask_width).
labels (Tensor): Either predicted or ground truth label for
each predicted mask, of length num_rois.
Returns:
scores (Tensor): Uncertainty scores with the most uncertain
locations having the highest uncertainty score,
shape (num_rois, 1, mask_height, mask_width)
"""
if mask_preds.shape[1] == 1:
gt_class_logits = mask_preds.clone()
else:
inds = torch.arange(mask_preds.shape[0], device=mask_preds.device)
gt_class_logits = mask_preds[inds, labels].unsqueeze(1)
return -torch.abs(gt_class_logits)
def get_uncertain_point_coords_with_randomness(
mask_preds: Tensor, labels: Tensor, num_points: int,
oversample_ratio: float, importance_sample_ratio: float) -> Tensor:
"""Get ``num_points`` most uncertain points with random points during
train.
Sample points in [0, 1] x [0, 1] coordinate space based on their
uncertainty. The uncertainties are calculated for each point using
'get_uncertainty()' function that takes point's logit prediction as
input.
Args:
mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
mask_height, mask_width) for class-specific or class-agnostic
prediction.
labels (Tensor): The ground truth class for each instance.
num_points (int): The number of points to sample.
oversample_ratio (float): Oversampling parameter.
importance_sample_ratio (float): Ratio of points that are sampled
            via importance sampling.
Returns:
        point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
            that contains the coordinates of the sampled points.
"""
assert oversample_ratio >= 1
assert 0 <= importance_sample_ratio <= 1
batch_size = mask_preds.shape[0]
num_sampled = int(num_points * oversample_ratio)
point_coords = torch.rand(
batch_size, num_sampled, 2, device=mask_preds.device)
point_logits = point_sample(mask_preds, point_coords)
# It is crucial to calculate uncertainty based on the sampled
# prediction value for the points. Calculating uncertainties of the
# coarse predictions first and sampling them for points leads to
# incorrect results. To illustrate this: assume uncertainty func(
# logits)=-abs(logits), a sampled point between two coarse
# predictions with -1 and 1 logits has 0 logits, and therefore 0
# uncertainty value. However, if we calculate uncertainties for the
# coarse predictions first, both will have -1 uncertainty,
# and sampled point will get -1 uncertainty.
point_uncertainties = get_uncertainty(point_logits, labels)
num_uncertain_points = int(importance_sample_ratio * num_points)
num_random_points = num_points - num_uncertain_points
idx = torch.topk(
point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
shift = num_sampled * torch.arange(
batch_size, dtype=torch.long, device=mask_preds.device)
idx += shift[:, None]
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
batch_size, num_uncertain_points, 2)
if num_random_points > 0:
rand_roi_coords = torch.rand(
batch_size, num_random_points, 2, device=mask_preds.device)
point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
return point_coords
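# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# A minimal call of the sampler above on class-agnostic mask logits; requires
# mmcv for `point_sample`. `_demo_uncertain_point_sampling` is a hypothetical
# helper name.
def _demo_uncertain_point_sampling():
    num_rois, h, w = 4, 16, 16
    mask_preds = torch.randn(num_rois, 1, h, w)           # class-agnostic logits
    labels = torch.zeros(num_rois, dtype=torch.long)      # unused for 1-channel preds
    point_coords = get_uncertain_point_coords_with_randomness(
        mask_preds, labels, num_points=12,
        oversample_ratio=3.0, importance_sample_ratio=0.75)
    assert point_coords.shape == (num_rois, 12, 2)        # normalized (x, y) coords
    return point_coords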
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from hydra import initialize_config_module
initialize_config_module("third_parts.sam2.sam2_configs", version_base="1.2")
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
from third_parts.sam2.modeling.sam2_base import SAM2Base
from third_parts.sam2.sam2_image_predictor import SAM2ImagePredictor
from third_parts.sam2.utils.amg import (
area_from_rle,
batch_iterator,
batched_mask_to_box,
box_xyxy_to_xywh,
build_all_layer_point_grids,
calculate_stability_score,
coco_encode_rle,
generate_crop_boxes,
is_box_near_crop_edge,
mask_to_rle_pytorch,
MaskData,
remove_small_regions,
rle_to_mask,
uncrop_boxes_xyxy,
uncrop_masks,
uncrop_points,
)
class SAM2AutomaticMaskGenerator:
def __init__(
self,
model: SAM2Base,
points_per_side: Optional[int] = 32,
points_per_batch: int = 64,
pred_iou_thresh: float = 0.8,
stability_score_thresh: float = 0.95,
stability_score_offset: float = 1.0,
mask_threshold: float = 0.0,
box_nms_thresh: float = 0.7,
crop_n_layers: int = 0,
crop_nms_thresh: float = 0.7,
crop_overlap_ratio: float = 512 / 1500,
crop_n_points_downscale_factor: int = 1,
point_grids: Optional[List[np.ndarray]] = None,
min_mask_region_area: int = 0,
output_mode: str = "binary_mask",
use_m2m: bool = False,
multimask_output: bool = True,
) -> None:
"""
Using a SAM 2 model, generates masks for the entire image.
Generates a grid of point prompts over the image, then filters
low quality and duplicate masks. The default settings are chosen
for SAM 2 with a HieraL backbone.
Arguments:
model (Sam): The SAM 2 model to use for mask prediction.
points_per_side (int or None): The number of points to be sampled
along one side of the image. The total number of points is
points_per_side**2. If None, 'point_grids' must provide explicit
point sampling.
points_per_batch (int): Sets the number of points run simultaneously
by the model. Higher numbers may be faster but use more GPU memory.
pred_iou_thresh (float): A filtering threshold in [0,1], using the
model's predicted mask quality.
stability_score_thresh (float): A filtering threshold in [0,1], using
the stability of the mask under changes to the cutoff used to binarize
the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          mask_threshold (float): Threshold for binarizing the mask logits.
box_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks.
crop_n_layers (int): If >0, mask prediction will be run again on
crops of the image. Sets the number of layers to run, where each
layer has 2**i_layer number of image crops.
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks between different crops.
crop_overlap_ratio (float): Sets the degree to which crops overlap.
In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
crop_n_points_downscale_factor (int): The number of points-per-side
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
point_grids (list(np.ndarray) or None): A list over explicit grids
of points used for sampling, normalized to [0,1]. The nth grid in the
list is used in the nth crop layer. Exclusive with points_per_side.
min_mask_region_area (int): If >0, postprocessing will be applied
to remove disconnected regions and holes in masks with area smaller
than min_mask_region_area. Requires opencv.
output_mode (str): The form masks are returned in. Can be 'binary_mask',
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
For large resolutions, 'binary_mask' may consume large amounts of
memory.
use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
multimask_output (bool): Whether to output multimask at each point of the grid.
"""
assert (points_per_side is None) != (
point_grids is None
), "Exactly one of points_per_side or point_grid must be provided."
if points_per_side is not None:
self.point_grids = build_all_layer_point_grids(
points_per_side,
crop_n_layers,
crop_n_points_downscale_factor,
)
elif point_grids is not None:
self.point_grids = point_grids
else:
raise ValueError("Can't have both points_per_side and point_grid be None.")
assert output_mode in [
"binary_mask",
"uncompressed_rle",
"coco_rle",
], f"Unknown output_mode {output_mode}."
if output_mode == "coco_rle":
try:
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
except ImportError as e:
print("Please install pycocotools")
raise e
self.predictor = SAM2ImagePredictor(
model,
max_hole_area=min_mask_region_area,
max_sprinkle_area=min_mask_region_area,
)
self.points_per_batch = points_per_batch
self.pred_iou_thresh = pred_iou_thresh
self.stability_score_thresh = stability_score_thresh
self.stability_score_offset = stability_score_offset
self.mask_threshold = mask_threshold
self.box_nms_thresh = box_nms_thresh
self.crop_n_layers = crop_n_layers
self.crop_nms_thresh = crop_nms_thresh
self.crop_overlap_ratio = crop_overlap_ratio
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
self.min_mask_region_area = min_mask_region_area
self.output_mode = output_mode
self.use_m2m = use_m2m
self.multimask_output = multimask_output
@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
"""
Generates masks for the given image.
Arguments:
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
Returns:
list(dict(str, any)): A list over records for masks. Each record is
a dict containing the following keys:
segmentation (dict(str, any) or np.ndarray): The mask. If
output_mode='binary_mask', is an array of shape HW. Otherwise,
is a dictionary containing the RLE.
bbox (list(float)): The box around the mask, in XYWH format.
area (int): The area in pixels of the mask.
predicted_iou (float): The model's own prediction of the mask's
quality. This is filtered by the pred_iou_thresh parameter.
point_coords (list(list(float))): The point coordinates input
to the model to generate this mask.
stability_score (float): A measure of the mask's quality. This
is filtered on using the stability_score_thresh parameter.
crop_box (list(float)): The crop of the image used to generate
the mask, given in XYWH format.
"""
# Generate masks
mask_data = self._generate_masks(image)
# Encode masks
if self.output_mode == "coco_rle":
mask_data["segmentations"] = [
coco_encode_rle(rle) for rle in mask_data["rles"]
]
elif self.output_mode == "binary_mask":
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
else:
mask_data["segmentations"] = mask_data["rles"]
# Write mask records
curr_anns = []
for idx in range(len(mask_data["segmentations"])):
ann = {
"segmentation": mask_data["segmentations"][idx],
"area": area_from_rle(mask_data["rles"][idx]),
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
"predicted_iou": mask_data["iou_preds"][idx].item(),
"point_coords": [mask_data["points"][idx].tolist()],
"stability_score": mask_data["stability_score"][idx].item(),
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
}
curr_anns.append(ann)
return curr_anns
def _generate_masks(self, image: np.ndarray) -> MaskData:
orig_size = image.shape[:2]
crop_boxes, layer_idxs = generate_crop_boxes(
orig_size, self.crop_n_layers, self.crop_overlap_ratio
)
# Iterate over image crops
data = MaskData()
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
data.cat(crop_data)
# Remove duplicate masks between crops
if len(crop_boxes) > 1:
# Prefer masks from smaller crops
scores = 1 / box_area(data["crop_boxes"])
scores = scores.to(data["boxes"].device)
keep_by_nms = batched_nms(
data["boxes"].float(),
scores,
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.crop_nms_thresh,
)
data.filter(keep_by_nms)
data.to_numpy()
return data
def _process_crop(
self,
image: np.ndarray,
crop_box: List[int],
crop_layer_idx: int,
orig_size: Tuple[int, ...],
) -> MaskData:
# Crop the image and calculate embeddings
x0, y0, x1, y1 = crop_box
cropped_im = image[y0:y1, x0:x1, :]
cropped_im_size = cropped_im.shape[:2]
self.predictor.set_image(cropped_im)
# Get points for this crop
points_scale = np.array(cropped_im_size)[None, ::-1]
points_for_image = self.point_grids[crop_layer_idx] * points_scale
# Generate masks for this crop in batches
data = MaskData()
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
batch_data = self._process_batch(
points, cropped_im_size, crop_box, orig_size, normalize=True
)
data.cat(batch_data)
del batch_data
self.predictor.reset_predictor()
# Remove duplicates within this crop.
keep_by_nms = batched_nms(
data["boxes"].float(),
data["iou_preds"],
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.box_nms_thresh,
)
data.filter(keep_by_nms)
# Return to the original image frame
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
data["points"] = uncrop_points(data["points"], crop_box)
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
return data
def _process_batch(
self,
points: np.ndarray,
im_size: Tuple[int, ...],
crop_box: List[int],
orig_size: Tuple[int, ...],
normalize=False,
) -> MaskData:
orig_h, orig_w = orig_size
# Run model on this batch
points = torch.as_tensor(points, device=self.predictor.device)
in_points = self.predictor._transforms.transform_coords(
points, normalize=normalize, orig_hw=im_size
)
in_labels = torch.ones(
in_points.shape[0], dtype=torch.int, device=in_points.device
)
masks, iou_preds, low_res_masks = self.predictor._predict(
in_points[:, None, :],
in_labels[:, None],
multimask_output=self.multimask_output,
return_logits=True,
)
# Serialize predictions and store in MaskData
data = MaskData(
masks=masks.flatten(0, 1),
iou_preds=iou_preds.flatten(0, 1),
points=points.repeat_interleave(masks.shape[1], dim=0),
low_res_masks=low_res_masks.flatten(0, 1),
)
del masks
if not self.use_m2m:
# Filter by predicted IoU
if self.pred_iou_thresh > 0.0:
keep_mask = data["iou_preds"] > self.pred_iou_thresh
data.filter(keep_mask)
# Calculate and filter by stability score
data["stability_score"] = calculate_stability_score(
data["masks"], self.mask_threshold, self.stability_score_offset
)
if self.stability_score_thresh > 0.0:
keep_mask = data["stability_score"] >= self.stability_score_thresh
data.filter(keep_mask)
else:
# One step refinement using previous mask predictions
in_points = self.predictor._transforms.transform_coords(
data["points"], normalize=normalize, orig_hw=im_size
)
labels = torch.ones(
in_points.shape[0], dtype=torch.int, device=in_points.device
)
masks, ious = self.refine_with_m2m(
in_points, labels, data["low_res_masks"], self.points_per_batch
)
data["masks"] = masks.squeeze(1)
data["iou_preds"] = ious.squeeze(1)
if self.pred_iou_thresh > 0.0:
keep_mask = data["iou_preds"] > self.pred_iou_thresh
data.filter(keep_mask)
data["stability_score"] = calculate_stability_score(
data["masks"], self.mask_threshold, self.stability_score_offset
)
if self.stability_score_thresh > 0.0:
keep_mask = data["stability_score"] >= self.stability_score_thresh
data.filter(keep_mask)
# Threshold masks and calculate boxes
data["masks"] = data["masks"] > self.mask_threshold
data["boxes"] = batched_mask_to_box(data["masks"])
# Filter boxes that touch crop boundaries
keep_mask = ~is_box_near_crop_edge(
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
)
if not torch.all(keep_mask):
data.filter(keep_mask)
# Compress to RLE
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
data["rles"] = mask_to_rle_pytorch(data["masks"])
del data["masks"]
return data
@staticmethod
def postprocess_small_regions(
mask_data: MaskData, min_area: int, nms_thresh: float
) -> MaskData:
"""
Removes small disconnected regions and holes in masks, then reruns
box NMS to remove any new duplicates.
Edits mask_data in place.
Requires open-cv as a dependency.
"""
if len(mask_data["rles"]) == 0:
return mask_data
# Filter small disconnected regions and holes
new_masks = []
scores = []
for rle in mask_data["rles"]:
mask = rle_to_mask(rle)
mask, changed = remove_small_regions(mask, min_area, mode="holes")
unchanged = not changed
mask, changed = remove_small_regions(mask, min_area, mode="islands")
unchanged = unchanged and not changed
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
# Give score=0 to changed masks and score=1 to unchanged masks
# so NMS will prefer ones that didn't need postprocessing
scores.append(float(unchanged))
# Recalculate boxes and remove any new duplicates
masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(masks)
keep_by_nms = batched_nms(
boxes.float(),
torch.as_tensor(scores),
torch.zeros_like(boxes[:, 0]), # categories
iou_threshold=nms_thresh,
)
# Only recalculate RLEs for masks that have changed
for i_mask in keep_by_nms:
if scores[i_mask] == 0.0:
mask_torch = masks[i_mask].unsqueeze(0)
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
mask_data.filter(keep_by_nms)
return mask_data
def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
new_masks = []
new_iou_preds = []
for cur_points, cur_point_labels, low_res_mask in batch_iterator(
points_per_batch, points, point_labels, low_res_masks
):
best_masks, best_iou_preds, _ = self.predictor._predict(
cur_points[:, None, :],
cur_point_labels[:, None],
mask_input=low_res_mask[:, None, :],
multimask_output=False,
return_logits=True,
)
new_masks.append(best_masks)
new_iou_preds.append(best_iou_preds)
masks = torch.cat(new_masks, dim=0)
return masks, torch.cat(new_iou_preds, dim=0)
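# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# How this generator is typically driven. The import path, config name and
# checkpoint path below are assumptions/placeholders (the builder is defined in
# the build module included later in this commit, following the upstream SAM 2
# layout); `_demo_automatic_mask_generator` is a hypothetical helper name.
def _demo_automatic_mask_generator():
    from third_parts.sam2.build_sam import build_sam2  # assumed module path
    sam2_model = build_sam2("sam2_hiera_l.yaml", ckpt_path="./checkpoints/sam2_hiera_large.pt")
    generator = SAM2AutomaticMaskGenerator(sam2_model, points_per_side=16)
    image = np.zeros((512, 512, 3), dtype=np.uint8)    # HWC uint8, as generate() expects
    masks = generator.generate(image)
    # Each record has 'segmentation', 'bbox', 'area', 'predicted_iou',
    # 'point_coords', 'stability_score' and 'crop_box'.
    return masks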
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
def build_sam2(
config_file,
ckpt_path=None,
device="cuda",
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
):
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
]
# Read config and init model
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)
_load_checkpoint(model, ckpt_path)
model = model.to(device)
if mode == "eval":
model.eval()
return model
def build_sam2_video_predictor(
config_file,
ckpt_path=None,
device="cuda",
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
):
hydra_overrides = [
"++model._target_=third_parts.sam2.sam2_video_predictor.SAM2VideoPredictor",
]
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
"++model.binarize_mask_from_pts_for_mem_enc=true",
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
"++model.fill_hole_area=8",
]
hydra_overrides.extend(hydra_overrides_extra)
# Read config and init model
cfg = compose(config_name=config_file, overrides=hydra_overrides)
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)
_load_checkpoint(model, ckpt_path)
model = model.to(device)
if mode == "eval":
model.eval()
return model
def _load_checkpoint(model, ckpt_path):
if ckpt_path is not None:
sd = torch.load(ckpt_path, map_location="cpu")["model"]
missing_keys, unexpected_keys = model.load_state_dict(sd)
if missing_keys:
logging.error(missing_keys)
raise RuntimeError()
if unexpected_keys:
logging.error(unexpected_keys)
raise RuntimeError()
logging.info("Loaded checkpoint sucessfully")
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
// adapted from https://github.com/zsef123/Connected_components_PyTorch
// with license found in the LICENSE_cctorch file in the root directory.
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <torch/script.h>
#include <vector>
// 2d
#define BLOCK_ROWS 16
#define BLOCK_COLS 16
namespace cc2d {
template <typename T>
__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
return (bitmap >> pos) & 1;
}
__device__ int32_t find(const int32_t* s_buf, int32_t n) {
while (s_buf[n] != n)
n = s_buf[n];
return n;
}
__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
const int32_t id = n;
while (s_buf[n] != n) {
n = s_buf[n];
s_buf[id] = n;
}
return n;
}
__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
bool done;
do {
a = find(s_buf, a);
b = find(s_buf, b);
if (a < b) {
int32_t old = atomicMin(s_buf + b, a);
done = (old == b);
b = old;
} else if (b < a) {
int32_t old = atomicMin(s_buf + a, b);
done = (old == a);
a = old;
} else
done = true;
} while (!done);
}
__global__ void
init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
const uint32_t idx = row * W + col;
if (row < H && col < W)
label[idx] = idx;
}
__global__ void
merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
const uint32_t idx = row * W + col;
if (row >= H || col >= W)
return;
uint32_t P = 0;
if (img[idx])
P |= 0x777;
if (row + 1 < H && img[idx + W])
P |= 0x777 << 4;
if (col + 1 < W && img[idx + 1])
P |= 0x777 << 1;
if (col == 0)
P &= 0xEEEE;
if (col + 1 >= W)
P &= 0x3333;
else if (col + 2 >= W)
P &= 0x7777;
if (row == 0)
P &= 0xFFF0;
if (row + 1 >= H)
P &= 0xFF;
if (P > 0) {
// If need check about top-left pixel(if flag the first bit) and hit the
// top-left pixel
if (hasBit(P, 0) && img[idx - W - 1]) {
union_(label, idx, idx - 2 * W - 2); // top left block
}
if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
union_(label, idx, idx - 2 * W); // top bottom block
if (hasBit(P, 3) && img[idx + 2 - W])
union_(label, idx, idx - 2 * W + 2); // top right block
if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
union_(label, idx, idx - 2); // just left block
}
}
__global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
const uint32_t idx = row * W + col;
if (row < H && col < W)
find_n_compress(label, idx);
}
__global__ void final_labeling(
const uint8_t* img,
int32_t* label,
const int32_t W,
const int32_t H) {
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
const uint32_t idx = row * W + col;
if (row >= H || col >= W)
return;
int32_t y = label[idx] + 1;
if (img[idx])
label[idx] = y;
else
label[idx] = 0;
if (col + 1 < W) {
if (img[idx + 1])
label[idx + 1] = y;
else
label[idx + 1] = 0;
if (row + 1 < H) {
if (img[idx + W + 1])
label[idx + W + 1] = y;
else
label[idx + W + 1] = 0;
}
}
if (row + 1 < H) {
if (img[idx + W])
label[idx + W] = y;
else
label[idx + W] = 0;
}
}
__global__ void init_counting(
const int32_t* label,
int32_t* count_init,
const int32_t W,
const int32_t H) {
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
const uint32_t idx = row * W + col;
if (row >= H || col >= W)
return;
int32_t y = label[idx];
if (y > 0) {
int32_t count_idx = y - 1;
atomicAdd(count_init + count_idx, 1);
}
}
__global__ void final_counting(
const int32_t* label,
const int32_t* count_init,
int32_t* count_final,
const int32_t W,
const int32_t H) {
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
const uint32_t idx = row * W + col;
if (row >= H || col >= W)
return;
int32_t y = label[idx];
if (y > 0) {
int32_t count_idx = y - 1;
count_final[idx] = count_init[count_idx];
} else {
count_final[idx] = 0;
}
}
} // namespace cc2d
std::vector<torch::Tensor> get_connected_componnets(
const torch::Tensor& inputs) {
AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
AT_ASSERTM(
inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
const uint32_t N = inputs.size(0);
const uint32_t C = inputs.size(1);
const uint32_t H = inputs.size(2);
const uint32_t W = inputs.size(3);
AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
AT_ASSERTM((H % 2) == 0, "height must be an even number");
AT_ASSERTM((W % 2) == 0, "width must be an even number");
// label must be uint32_t
auto label_options =
torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
dim3 grid = dim3(
((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
dim3 grid_count =
dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
for (int n = 0; n < N; n++) {
uint32_t offset = n * H * W;
cc2d::init_labeling<<<grid, block, 0, stream>>>(
labels.data_ptr<int32_t>() + offset, W, H);
cc2d::merge<<<grid, block, 0, stream>>>(
inputs.data_ptr<uint8_t>() + offset,
labels.data_ptr<int32_t>() + offset,
W,
H);
cc2d::compression<<<grid, block, 0, stream>>>(
labels.data_ptr<int32_t>() + offset, W, H);
cc2d::final_labeling<<<grid, block, 0, stream>>>(
inputs.data_ptr<uint8_t>() + offset,
labels.data_ptr<int32_t>() + offset,
W,
H);
// get the counting of each pixel
cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
labels.data_ptr<int32_t>() + offset,
counts_init.data_ptr<int32_t>() + offset,
W,
H);
cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
labels.data_ptr<int32_t>() + offset,
counts_init.data_ptr<int32_t>() + offset,
counts_final.data_ptr<int32_t>() + offset,
W,
H);
}
// returned values are [labels, counts]
std::vector<torch::Tensor> outputs;
outputs.push_back(labels);
outputs.push_back(counts_final);
return outputs;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"get_connected_componnets",
&get_connected_componnets,
"get_connected_componnets");
}
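# --- Illustrative usage sketch (added for clarity; not part of the original kernel file) ---
# A minimal, assumption-laden way to compile and call the connected-components
# kernel above from Python via torch.utils.cpp_extension. The extension name
# "cc_torch" and the source file name are hypothetical; in practice this kernel
# is built as part of the SAM 2 CUDA extension used by fill_holes_in_mask_scores.
def _demo_connected_components():
    import torch
    from torch.utils.cpp_extension import load
    cc = load(name="cc_torch", sources=["connected_components.cu"])  # hypothetical path
    masks = (torch.rand(1, 1, 64, 64, device="cuda") > 0.5).to(torch.uint8)  # [N,1,H,W], even H/W
    labels, counts = cc.get_connected_componnets(masks)  # int32 labels and per-pixel region sizes
    return labels, counts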