# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import torch
from tqdm import tqdm
from third_parts.sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from third_parts.sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
class SAM2VideoPredictor(SAM2Base):
"""The predictor class to handle user interactions and manage inference states."""
def __init__(
self,
fill_hole_area=0,
# whether to apply non-overlapping constraints on the output object masks
non_overlap_masks=False,
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
clear_non_cond_mem_around_input=False,
# whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
clear_non_cond_mem_for_multi_obj=False,
**kwargs,
):
super().__init__(**kwargs)
self.fill_hole_area = fill_hole_area
self.non_overlap_masks = non_overlap_masks
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
@torch.inference_mode()
def init_state(
self,
video_path,
offload_video_to_cpu=False,
offload_state_to_cpu=False,
async_loading_frames=False,
):
"""Initialize an inference state."""
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
)
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
# whether to offload the inference state to CPU memory
# turning on this option saves the GPU memory at the cost of a lower tracking fps
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
# and from 24 to 21 when tracking two objects)
inference_state["offload_state_to_cpu"] = offload_state_to_cpu
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = torch.device("cuda")
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = torch.device("cuda")
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
# visual features on a small number of recently visited frames for quick interactions
inference_state["cached_features"] = {}
# values that don't change across frames (so we only need to hold one copy of them)
inference_state["constants"] = {}
# mapping between client-side object id and model-side object index
inference_state["obj_id_to_idx"] = OrderedDict()
inference_state["obj_idx_to_id"] = OrderedDict()
inference_state["obj_ids"] = []
# A storage to hold the model's tracking results and states on each frame
inference_state["output_dict"] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
inference_state["output_dict_per_obj"] = {}
# A temporary storage to hold new outputs when the user interacts with a frame
# to add clicks or a mask (it's merged into "output_dict" before propagation starts)
inference_state["temp_output_dict_per_obj"] = {}
# Frames that already hold consolidated outputs from click or mask inputs
# (we directly use their consolidated outputs during tracking)
inference_state["consolidated_frame_inds"] = {
"cond_frame_outputs": set(), # set containing frame indices
"non_cond_frame_outputs": set(), # set containing frame indices
}
# metadata for each tracking frame (e.g. which direction it's tracked)
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"] = {}
# Warm up the visual backbone and cache the image feature on frame 0
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state
def _obj_id_to_idx(self, inference_state, obj_id):
"""Map client-side object id to model-side object index."""
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
if obj_idx is not None:
return obj_idx
# This is a new object id not sent to the server before. We only allow adding
# new objects *before* the tracking starts.
allow_new_object = not inference_state["tracking_has_started"]
if allow_new_object:
# get the next object slot
obj_idx = len(inference_state["obj_id_to_idx"])
inference_state["obj_id_to_idx"][obj_id] = obj_idx
inference_state["obj_idx_to_id"][obj_idx] = obj_id
inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
# set up input and output structures for this object
inference_state["point_inputs_per_obj"][obj_idx] = {}
inference_state["mask_inputs_per_obj"][obj_idx] = {}
inference_state["output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
inference_state["temp_output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
return obj_idx
else:
raise RuntimeError(
f"Cannot add new object id {obj_id} after tracking starts. "
f"All existing object ids: {inference_state['obj_ids']}. "
f"Please call 'reset_state' to restart from scratch."
)
def _obj_idx_to_id(self, inference_state, obj_idx):
"""Map model-side object index to client-side object id."""
return inference_state["obj_idx_to_id"][obj_idx]
def _get_obj_num(self, inference_state):
"""Get the total number of unique object ids received so far in this session."""
return len(inference_state["obj_idx_to_id"])
@torch.inference_mode()
def add_new_points(
self,
inference_state,
frame_idx,
obj_id,
points,
labels,
clear_old_points=True,
normalize_coords=True,
):
"""Add new points to a frame."""
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
if not isinstance(points, torch.Tensor):
points = torch.tensor(points, dtype=torch.float32)
if not isinstance(labels, torch.Tensor):
labels = torch.tensor(labels, dtype=torch.int32)
if points.dim() == 2:
points = points.unsqueeze(0) # add batch dimension
if labels.dim() == 1:
labels = labels.unsqueeze(0) # add batch dimension
if normalize_coords:
video_H = inference_state["video_height"]
video_W = inference_state["video_width"]
points = points / torch.tensor([video_W, video_H]).to(points.device)
# scale the (normalized) coordinates by the model's internal image size
points = points * self.image_size
points = points.to(inference_state["device"])
labels = labels.to(inference_state["device"])
if not clear_old_points:
point_inputs = point_inputs_per_frame.get(frame_idx, None)
else:
point_inputs = None
point_inputs = concat_points(point_inputs, points, labels)
point_inputs_per_frame[frame_idx] = point_inputs
mask_inputs_per_frame.pop(frame_idx, None)
# If this frame hasn't been tracked before, we treat it as an initial conditioning
# frame, meaning that the input points are used to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input points will be used to correct the already tracked masks.
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
# whether to track in reverse time order
if is_init_cond_frame:
reverse = False
else:
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or
# if the model sees all frames receiving clicks/mask as conditioning frames.
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Get any previously predicted mask logits on this object and feed it along with
# the new clicks into the SAM mask decoder.
prev_sam_mask_logits = None
# lookup temporary output dict first, which contains the most recent output
# (if not found, then lookup conditioning and non-conditioning frame output)
prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
if prev_out is None:
prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
if prev_out is None:
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
if prev_out is not None and prev_out["pred_masks"] is not None:
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, _ = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=obj_output_dict, # run on the slice of a single object
frame_idx=frame_idx,
batch_size=1, # run on the slice of a single object
is_init_cond_frame=is_init_cond_frame,
point_inputs=point_inputs,
mask_inputs=None,
reverse=reverse,
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
# at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
# allows us to enforce non-overlapping constraints on all objects before encoding
# them into memory.
run_mem_encoder=False,
prev_sam_mask_logits=prev_sam_mask_logits,
)
# Add the output to the output dict (to be used as future memory)
obj_temp_output_dict[storage_key][frame_idx] = current_out
# Resize the output mask to the original video resolution
obj_ids = inference_state["obj_ids"]
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
return frame_idx, obj_ids, video_res_masks
@torch.inference_mode()
def add_new_mask(
self,
inference_state,
frame_idx,
obj_id,
mask,
):
"""Add new mask to a frame."""
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
if not isinstance(mask, torch.Tensor):
mask = torch.tensor(mask, dtype=torch.bool)
assert mask.dim() == 2
mask_H, mask_W = mask.shape
mask_inputs_orig = mask[None, None] # add batch and channel dimension
mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
# resize the mask if it doesn't match the model's image size
if mask_H != self.image_size or mask_W != self.image_size:
mask_inputs = torch.nn.functional.interpolate(
mask_inputs_orig,
size=(self.image_size, self.image_size),
align_corners=False,
mode="bilinear",
antialias=True, # use antialias for downsampling
)
mask_inputs = (mask_inputs >= 0.5).float()
else:
mask_inputs = mask_inputs_orig
mask_inputs_per_frame[frame_idx] = mask_inputs
point_inputs_per_frame.pop(frame_idx, None)
# If this frame hasn't been tracked before, we treat it as an initial conditioning
# frame, meaning that the input mask is used to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input mask will be used to correct the already tracked masks.
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
# whether to track in reverse time order
if is_init_cond_frame:
reverse = False
else:
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or
# if the model sees all frames receiving clicks/mask as conditioning frames.
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
current_out, _ = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=obj_output_dict, # run on the slice of a single object
frame_idx=frame_idx,
batch_size=1, # run on the slice of a single object
is_init_cond_frame=is_init_cond_frame,
point_inputs=None,
mask_inputs=mask_inputs,
reverse=reverse,
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
# at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
# allows us to enforce non-overlapping constraints on all objects before encoding
# them into memory.
run_mem_encoder=False,
)
# Add the output to the output dict (to be used as future memory)
obj_temp_output_dict[storage_key][frame_idx] = current_out
# Resize the output mask to the original video resolution
obj_ids = inference_state["obj_ids"]
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
return frame_idx, obj_ids, video_res_masks
def _get_orig_video_res_output(self, inference_state, any_res_masks):
"""
Resize the object scores to the original video resolution (video_res_masks)
and apply non-overlapping constraints for final output.
"""
device = inference_state["device"]
video_H = inference_state["video_height"]
video_W = inference_state["video_width"]
any_res_masks = any_res_masks.to(device, non_blocking=True)
if any_res_masks.shape[-2:] == (video_H, video_W):
video_res_masks = any_res_masks
else:
video_res_masks = torch.nn.functional.interpolate(
any_res_masks,
size=(video_H, video_W),
mode="bilinear",
align_corners=False,
)
if self.non_overlap_masks:
video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
return any_res_masks, video_res_masks
def _consolidate_temp_output_across_obj(
self,
inference_state,
frame_idx,
is_cond,
run_mem_encoder,
consolidate_at_video_res=False,
):
"""
Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
a frame into a single output for all objects, including
1) filling any missing objects either from `output_dict_per_obj` (if they exist in
`output_dict_per_obj` for this frame) or leaving them as placeholder values
(if they don't exist in `output_dict_per_obj` for this frame);
2) if specified, rerunning the memory encoder after applying non-overlapping
constraints on the object scores.
"""
batch_size = self._get_obj_num(inference_state)
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Optionally, we allow consolidating the temporary outputs at the original
# video resolution (to provide a better editing experience for mask prompts).
if consolidate_at_video_res:
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
consolidated_H = inference_state["video_height"]
consolidated_W = inference_state["video_width"]
consolidated_mask_key = "pred_masks_video_res"
else:
consolidated_H = consolidated_W = self.image_size // 4
consolidated_mask_key = "pred_masks"
# Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
# will be added when rerunning the memory encoder after applying non-overlapping
# constraints to object scores. Its "pred_masks" are prefilled with a large
# negative value (NO_OBJ_SCORE) to represent missing objects.
consolidated_out = {
"maskmem_features": None,
"maskmem_pos_enc": None,
consolidated_mask_key: torch.full(
size=(batch_size, 1, consolidated_H, consolidated_W),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["storage_device"],
),
"obj_ptr": torch.full(
size=(batch_size, self.hidden_dim),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["device"],
),
}
empty_mask_ptr = None
for obj_idx in range(batch_size):
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
out = obj_temp_output_dict[storage_key].get(frame_idx, None)
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
# we fall back and look up its previous output in "output_dict_per_obj".
# We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
# "output_dict_per_obj" to find a previous output for this object.
if out is None:
out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
if out is None:
out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
# If the object doesn't appear in "output_dict_per_obj" either, we skip it
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
# placeholder above) and set its object pointer to be a dummy pointer.
if out is None:
# Fill in dummy object pointers for those objects without any inputs or
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
# i.e. when we need to build the memory for tracking).
if run_mem_encoder:
if empty_mask_ptr is None:
empty_mask_ptr = self._get_empty_mask_ptr(
inference_state, frame_idx
)
# fill object pointer with a dummy pointer (based on an empty mask)
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
continue
# Add the temporary object output mask to consolidated output mask
obj_mask = out["pred_masks"]
consolidated_pred_masks = consolidated_out[consolidated_mask_key]
if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
else:
# Resize first if temporary object mask has a different resolution
resized_obj_mask = torch.nn.functional.interpolate(
obj_mask,
size=consolidated_pred_masks.shape[-2:],
mode="bilinear",
align_corners=False,
)
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
# Optionally, apply non-overlapping constraints on the consolidated scores
# and rerun the memory encoder
if run_mem_encoder:
device = inference_state["device"]
high_res_masks = torch.nn.functional.interpolate(
consolidated_out["pred_masks"].to(device, non_blocking=True),
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
if self.non_overlap_masks_for_mem_enc:
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
inference_state=inference_state,
frame_idx=frame_idx,
batch_size=batch_size,
high_res_masks=high_res_masks,
is_mask_from_pts=True, # these frames are what the user interacted with
)
consolidated_out["maskmem_features"] = maskmem_features
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
return consolidated_out
def _get_empty_mask_ptr(self, inference_state, frame_idx):
"""Get a dummy object pointer based on an empty mask on the current frame."""
# A dummy (empty) mask with a single object
batch_size = 1
mask_inputs = torch.zeros(
(batch_size, 1, self.image_size, self.image_size),
dtype=torch.float32,
device=inference_state["device"],
)
# Retrieve correct image features
(
_,
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = self._get_image_feature(inference_state, frame_idx, batch_size)
# Feed the empty mask and image feature above to get a dummy object pointer
current_out = self.track_step(
frame_idx=frame_idx,
is_init_cond_frame=True,
current_vision_feats=current_vision_feats,
current_vision_pos_embeds=current_vision_pos_embeds,
feat_sizes=feat_sizes,
point_inputs=None,
mask_inputs=mask_inputs,
output_dict={},
num_frames=inference_state["num_frames"],
track_in_reverse=False,
run_mem_encoder=False,
prev_sam_mask_logits=None,
)
return current_out["obj_ptr"]
@torch.inference_mode()
def propagate_in_video_preflight(self, inference_state):
"""Prepare inference_state and consolidate temporary outputs before tracking."""
# Tracking has started and we don't allow adding new objects until session is reset.
inference_state["tracking_has_started"] = True
batch_size = self._get_obj_num(inference_state)
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
# add them into "output_dict".
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
output_dict = inference_state["output_dict"]
# "consolidated_frame_inds" contains indices of those frames where consolidated
# temporary outputs have been added (either in this call or any previous calls
# to `propagate_in_video_preflight`).
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
for is_cond in [False, True]:
# Separately consolidate conditioning and non-conditioning temp outputs
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Find all the frames that contain temporary outputs for any objects
# (these should be the frames that have just received clicks or mask inputs
# via `add_new_points` or `add_new_mask`)
temp_frame_inds = set()
for obj_temp_output_dict in temp_output_dict_per_obj.values():
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
consolidated_frame_inds[storage_key].update(temp_frame_inds)
# consolidate the temporary output across all objects on this frame
for frame_idx in temp_frame_inds:
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
)
# merge them into "output_dict" and also create per-object slices
output_dict[storage_key][frame_idx] = consolidated_out
self._add_output_per_object(
inference_state, frame_idx, consolidated_out, storage_key
)
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
if clear_non_cond_mem:
# clear non-conditioning memory of the surrounding frames
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
# clear temporary outputs in `temp_output_dict_per_obj`
for obj_temp_output_dict in temp_output_dict_per_obj.values():
obj_temp_output_dict[storage_key].clear()
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
# output on the same frame in "non_cond_frame_outputs"
for frame_idx in output_dict["cond_frame_outputs"]:
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
for obj_output_dict in inference_state["output_dict_per_obj"].values():
for frame_idx in obj_output_dict["cond_frame_outputs"]:
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
assert frame_idx in output_dict["cond_frame_outputs"]
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
# with either points or mask inputs (which should be true under a correct workflow).
all_consolidated_frame_inds = (
consolidated_frame_inds["cond_frame_outputs"]
| consolidated_frame_inds["non_cond_frame_outputs"]
)
input_frames_inds = set()
for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
input_frames_inds.update(point_inputs_per_frame.keys())
for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
input_frames_inds.update(mask_inputs_per_frame.keys())
assert all_consolidated_frame_inds == input_frames_inds
@torch.inference_mode()
def propagate_in_video(
self,
inference_state,
start_frame_idx=None,
max_frame_num_to_track=None,
reverse=False,
):
"""Propagate the input points across frames to track in the entire video."""
self.propagate_in_video_preflight(inference_state)
output_dict = inference_state["output_dict"]
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
obj_ids = inference_state["obj_ids"]
num_frames = inference_state["num_frames"]
batch_size = self._get_obj_num(inference_state)
if len(output_dict["cond_frame_outputs"]) == 0:
raise RuntimeError("No points are provided; please add points first")
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
# set start index, end index, and processing order
if start_frame_idx is None:
# default: start from the earliest frame with input points
start_frame_idx = min(output_dict["cond_frame_outputs"])
if max_frame_num_to_track is None:
# default: track all the frames in the video
max_frame_num_to_track = num_frames
if reverse:
end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
if start_frame_idx > 0:
processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
else:
processing_order = [] # skip reverse tracking if starting from frame 0
else:
end_frame_idx = min(
start_frame_idx + max_frame_num_to_track, num_frames - 1
)
processing_order = range(start_frame_idx, end_frame_idx + 1)
for frame_idx in tqdm(processing_order, desc="propagate in video"):
# We skip those frames already in consolidated outputs (these are frames
# that received input clicks or mask). Note that we cannot directly run
# batched forward on them via `_run_single_frame_inference` because the
# number of clicks on each object might be different.
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
storage_key = "cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
if clear_non_cond_mem:
# clear non-conditioning memory of the surrounding frames
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
storage_key = "non_cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
else:
storage_key = "non_cond_frame_outputs"
current_out, pred_masks = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=output_dict,
frame_idx=frame_idx,
batch_size=batch_size,
is_init_cond_frame=False,
point_inputs=None,
mask_inputs=None,
reverse=reverse,
run_mem_encoder=True,
)
output_dict[storage_key][frame_idx] = current_out
# Create slices of per-object outputs for subsequent interaction with each
# individual object after tracking.
self._add_output_per_object(
inference_state, frame_idx, current_out, storage_key
)
inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
# Resize the output mask to the original video resolution (we directly use
# the mask scores on GPU for output to avoid any CPU conversion in between)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, pred_masks
)
yield frame_idx, obj_ids, video_res_masks
def _add_output_per_object(
self, inference_state, frame_idx, current_out, storage_key
):
"""
Split a multi-object output into per-object output slices and add them into
`output_dict_per_obj`. The resulting slices share the same tensor storage.
"""
maskmem_features = current_out["maskmem_features"]
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
maskmem_pos_enc = current_out["maskmem_pos_enc"]
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
output_dict_per_obj = inference_state["output_dict_per_obj"]
for obj_idx, obj_output_dict in output_dict_per_obj.items():
obj_slice = slice(obj_idx, obj_idx + 1)
obj_out = {
"maskmem_features": None,
"maskmem_pos_enc": None,
"pred_masks": current_out["pred_masks"][obj_slice],
"obj_ptr": current_out["obj_ptr"][obj_slice],
}
if maskmem_features is not None:
obj_out["maskmem_features"] = maskmem_features[obj_slice]
if maskmem_pos_enc is not None:
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
obj_output_dict[storage_key][frame_idx] = obj_out
@torch.inference_mode()
def reset_state(self, inference_state):
"""Remove all input points or masks in all frames throughout the video."""
self._reset_tracking_results(inference_state)
# Remove all object ids
inference_state["obj_id_to_idx"].clear()
inference_state["obj_idx_to_id"].clear()
inference_state["obj_ids"].clear()
inference_state["point_inputs_per_obj"].clear()
inference_state["mask_inputs_per_obj"].clear()
inference_state["output_dict_per_obj"].clear()
inference_state["temp_output_dict_per_obj"].clear()
def _reset_tracking_results(self, inference_state):
"""Reset all tracking inputs and results across the videos."""
for v in inference_state["point_inputs_per_obj"].values():
v.clear()
for v in inference_state["mask_inputs_per_obj"].values():
v.clear()
for v in inference_state["output_dict_per_obj"].values():
v["cond_frame_outputs"].clear()
v["non_cond_frame_outputs"].clear()
for v in inference_state["temp_output_dict_per_obj"].values():
v["cond_frame_outputs"].clear()
v["non_cond_frame_outputs"].clear()
inference_state["output_dict"]["cond_frame_outputs"].clear()
inference_state["output_dict"]["non_cond_frame_outputs"].clear()
inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"].clear()
def _get_image_feature(self, inference_state, frame_idx, batch_size):
"""Compute the image features on a given frame."""
# Look up in the cache first
image, backbone_out = inference_state["cached_features"].get(
frame_idx, (None, None)
)
if backbone_out is None:
# Cache miss -- we will run inference on a single image
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future).
inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
# expand the features to have the same dimension as the number of objects
expanded_image = image.expand(batch_size, -1, -1, -1)
expanded_backbone_out = {
"backbone_fpn": backbone_out["backbone_fpn"].copy(),
"vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
}
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
expanded_backbone_out["backbone_fpn"][i] = feat.expand(
batch_size, -1, -1, -1
)
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
pos = pos.expand(batch_size, -1, -1, -1)
expanded_backbone_out["vision_pos_enc"][i] = pos
features = self._prepare_backbone_features(expanded_backbone_out)
features = (expanded_image,) + features
return features
def _run_single_frame_inference(
self,
inference_state,
output_dict,
frame_idx,
batch_size,
is_init_cond_frame,
point_inputs,
mask_inputs,
reverse,
run_mem_encoder,
prev_sam_mask_logits=None,
):
"""Run tracking on a single frame based on current inputs and previous memory."""
# Retrieve correct image features
(
_,
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = self._get_image_feature(inference_state, frame_idx, batch_size)
# point and mask should not appear as input simultaneously on the same frame
assert point_inputs is None or mask_inputs is None
current_out = self.track_step(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats,
current_vision_pos_embeds=current_vision_pos_embeds,
feat_sizes=feat_sizes,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
output_dict=output_dict,
num_frames=inference_state["num_frames"],
track_in_reverse=reverse,
run_mem_encoder=run_mem_encoder,
prev_sam_mask_logits=prev_sam_mask_logits,
)
# optionally offload the output to CPU memory to save GPU space
storage_device = inference_state["storage_device"]
maskmem_features = current_out["maskmem_features"]
if maskmem_features is not None:
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
pred_masks_gpu = current_out["pred_masks"]
# potentially fill holes in the predicted masks
if self.fill_hole_area > 0:
pred_masks_gpu = fill_holes_in_mask_scores(
pred_masks_gpu, self.fill_hole_area
)
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
obj_ptr = current_out["obj_ptr"]
# make a compact version of this frame's output to reduce the state size
compact_current_out = {
"maskmem_features": maskmem_features,
"maskmem_pos_enc": maskmem_pos_enc,
"pred_masks": pred_masks,
"obj_ptr": obj_ptr,
}
return compact_current_out, pred_masks_gpu
def _run_memory_encoder(
self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
):
"""
Run the memory encoder on `high_res_masks`. This is usually after applying
non-overlapping constraints to object scores. Since their scores changed, their
memory also needs to be computed again with the memory encoder.
"""
# Retrieve correct image features
_, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
inference_state, frame_idx, batch_size
)
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_masks,
is_mask_from_pts=is_mask_from_pts,
)
# optionally offload the output to CPU memory to save GPU space
storage_device = inference_state["storage_device"]
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
maskmem_pos_enc = self._get_maskmem_pos_enc(
inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
)
return maskmem_features, maskmem_pos_enc
def _get_maskmem_pos_enc(self, inference_state, current_out):
"""
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
a constant in the inference session to reduce session storage size.
"""
model_constants = inference_state["constants"]
# "out_maskmem_pos_enc" should be either a list of tensors or None
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
if out_maskmem_pos_enc is not None:
if "maskmem_pos_enc" not in model_constants:
assert isinstance(out_maskmem_pos_enc, list)
# only take the slice for one object, since it's the same across objects
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
else:
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
# expand the cached maskmem_pos_enc to the actual batch size
batch_size = out_maskmem_pos_enc[0].size(0)
expanded_maskmem_pos_enc = [
x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
]
else:
expanded_maskmem_pos_enc = None
return expanded_maskmem_pos_enc
def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
"""
Remove the non-conditioning memory around the input frame. When users provide
correction clicks, the surrounding frames' non-conditioning memories can still
contain outdated object appearance information and could confuse the model.
This method clears those non-conditioning memories surrounding the interacted
frame to avoid giving the model both old and new information about the object.
"""
r = self.memory_temporal_stride_for_eval
frame_idx_begin = frame_idx - r * self.num_maskmem
frame_idx_end = frame_idx + r * self.num_maskmem
output_dict = inference_state["output_dict"]
non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
for t in range(frame_idx_begin, frame_idx_end + 1):
non_cond_frame_outputs.pop(t, None)
for obj_output_dict in inference_state["output_dict_per_obj"].values():
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
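# A minimal usage sketch for SAM2VideoPredictor (illustrative only: the frame
# directory, object id, and click coordinates below are made up, and `predictor`
# is assumed to be an already-built instance with loaded weights):
#
#   state = predictor.init_state(video_path="path/to/jpeg_frames")
#   frame_idx, obj_ids, masks = predictor.add_new_points(
#       state, frame_idx=0, obj_id=1,
#       points=[[210.0, 350.0]], labels=[1],  # one positive click
#   )
#   for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
#       pass  # masks: (num_objects, 1, video_height, video_width) mask logits
#   predictor.reset_state(state)  # clear all inputs/results to start over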
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple
import numpy as np
import torch
# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py
class MaskData:
"""
A structure for storing masks and their related data in batched format.
Implements basic filtering and concatenation.
"""
def __init__(self, **kwargs) -> None:
for v in kwargs.values():
assert isinstance(
v, (list, np.ndarray, torch.Tensor)
), "MaskData only supports list, numpy arrays, and torch tensors."
self._stats = dict(**kwargs)
def __setitem__(self, key: str, item: Any) -> None:
assert isinstance(
item, (list, np.ndarray, torch.Tensor)
), "MaskData only supports list, numpy arrays, and torch tensors."
self._stats[key] = item
def __delitem__(self, key: str) -> None:
del self._stats[key]
def __getitem__(self, key: str) -> Any:
return self._stats[key]
def items(self) -> ItemsView[str, Any]:
return self._stats.items()
def filter(self, keep: torch.Tensor) -> None:
for k, v in self._stats.items():
if v is None:
self._stats[k] = None
elif isinstance(v, torch.Tensor):
self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
elif isinstance(v, np.ndarray):
self._stats[k] = v[keep.detach().cpu().numpy()]
elif isinstance(v, list) and keep.dtype == torch.bool:
self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
elif isinstance(v, list):
self._stats[k] = [v[i] for i in keep]
else:
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
def cat(self, new_stats: "MaskData") -> None:
for k, v in new_stats.items():
if k not in self._stats or self._stats[k] is None:
self._stats[k] = deepcopy(v)
elif isinstance(v, torch.Tensor):
self._stats[k] = torch.cat([self._stats[k], v], dim=0)
elif isinstance(v, np.ndarray):
self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
elif isinstance(v, list):
self._stats[k] = self._stats[k] + deepcopy(v)
else:
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
def to_numpy(self) -> None:
for k, v in self._stats.items():
if isinstance(v, torch.Tensor):
self._stats[k] = v.float().detach().cpu().numpy()
def is_box_near_crop_edge(
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
return torch.any(near_crop_edge, dim=1)
def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
box_xywh = deepcopy(box_xyxy)
box_xywh[2] = box_xywh[2] - box_xywh[0]
box_xywh[3] = box_xywh[3] - box_xywh[1]
return box_xywh
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
assert len(args) > 0 and all(
len(a) == len(args[0]) for a in args
), "Batched iteration must have inputs of all the same size."
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
"""
Encodes masks to an uncompressed RLE, in the format expected by
pycoco tools.
"""
# Put in fortran order and flatten h,w
b, h, w = tensor.shape
tensor = tensor.permute(0, 2, 1).flatten(1)
# Compute change indices
diff = tensor[:, 1:] ^ tensor[:, :-1]
change_indices = diff.nonzero()
# Encode run length
out = []
for i in range(b):
cur_idxs = change_indices[change_indices[:, 0] == i, 1]
cur_idxs = torch.cat(
[
torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
cur_idxs + 1,
torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
]
)
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if tensor[i, 0] == 0 else [0]
counts.extend(btw_idxs.detach().cpu().tolist())
out.append({"size": [h, w], "counts": counts})
return out
def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
"""Compute a binary mask from an uncompressed RLE."""
h, w = rle["size"]
mask = np.empty(h * w, dtype=bool)
idx = 0
parity = False
for count in rle["counts"]:
mask[idx : idx + count] = parity
idx += count
parity ^= True
mask = mask.reshape(w, h)
return mask.transpose() # Put in C order
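# A tiny illustrative round trip through the two functions above (hypothetical
# 2x2 mask; "counts" alternate background/foreground run lengths over the
# column-major flattening of the mask):
#
#   m = torch.tensor([[[0, 1], [0, 1]]], dtype=torch.bool)  # shape (1, H=2, W=2)
#   rle = mask_to_rle_pytorch(m)[0]   # {"size": [2, 2], "counts": [2, 2]}
#   recovered = rle_to_mask(rle)      # same 2x2 boolean array, back in C order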
def area_from_rle(rle: Dict[str, Any]) -> int:
return sum(rle["counts"][1::2])
def calculate_stability_score(
masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
"""
Computes the stability score for a batch of masks. The stability
score is the IoU between the binary masks obtained by thresholding
the predicted mask logits at high and low values.
"""
# One mask is always contained inside the other.
# Save memory by preventing unnecessary cast to torch.int64
intersections = (
(masks > (mask_threshold + threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
)
unions = (
(masks > (mask_threshold - threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
)
return intersections / unions
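# Illustrative example (made-up logits): with mask_threshold=0.0 and
# threshold_offset=1.0, only logits above 1.0 count toward the intersection and
# logits above -1.0 toward the union:
#
#   logits = torch.tensor([[[2.5, 1.8], [-3.0, 0.2]]])
#   calculate_stability_score(logits, mask_threshold=0.0, threshold_offset=1.0)
#   # intersection = 2 pixels, union = 3 pixels -> stability score 2/3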
def build_point_grid(n_per_side: int) -> np.ndarray:
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
return points
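# Illustrative example: build_point_grid(2) returns a (4, 2) array of (x, y)
# points [0.25, 0.25], [0.75, 0.25], [0.25, 0.75], [0.75, 0.75], i.e. the centers
# of a 2x2 partition of the unit square.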
def build_all_layer_point_grids(
n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[np.ndarray]:
"""Generates point grids for all crop layers."""
points_by_layer = []
for i in range(n_layers + 1):
n_points = int(n_per_side / (scale_per_layer**i))
points_by_layer.append(build_point_grid(n_points))
return points_by_layer
def generate_crop_boxes(
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
"""
Generates a list of crop boxes of different sizes. Each layer
has (2**i)**2 boxes for the ith layer.
"""
crop_boxes, layer_idxs = [], []
im_h, im_w = im_size
short_side = min(im_h, im_w)
# Original image
crop_boxes.append([0, 0, im_w, im_h])
layer_idxs.append(0)
def crop_len(orig_len, n_crops, overlap):
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
for i_layer in range(n_layers):
n_crops_per_side = 2 ** (i_layer + 1)
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
crop_w = crop_len(im_w, n_crops_per_side, overlap)
crop_h = crop_len(im_h, n_crops_per_side, overlap)
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
# Crops in XYXY format
for x0, y0 in product(crop_box_x0, crop_box_y0):
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
crop_boxes.append(box)
layer_idxs.append(i_layer + 1)
return crop_boxes, layer_idxs
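# Illustrative example: for a 600x800 image, generate_crop_boxes((600, 800),
# n_layers=1, overlap_ratio=0.25) returns the full-image box plus 2**2 = 4
# overlapping crops for layer 1 (5 boxes total), with layer_idxs = [0, 1, 1, 1, 1].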
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
# Check if boxes has a channel dimension
if len(boxes.shape) == 3:
offset = offset.unsqueeze(1)
return boxes + offset
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0]], device=points.device)
# Check if points has a channel dimension
if len(points.shape) == 3:
offset = offset.unsqueeze(1)
return points + offset
def uncrop_masks(
masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
) -> torch.Tensor:
x0, y0, x1, y1 = crop_box
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
return masks
# Coordinate transform masks
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
pad = (x0, pad_x - x0, y0, pad_y - y0)
return torch.nn.functional.pad(masks, pad, value=0)
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of whether the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
from pycocotools import mask as mask_utils # type: ignore
h, w = uncompressed_rle["size"]
rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
return rle
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
"""
# torch.max below raises an error on empty inputs; just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
# Normalize shape to CxHxW
shape = masks.shape
h, w = shape[-2:]
if len(shape) > 2:
masks = masks.flatten(0, -3)
else:
masks = masks.unsqueeze(0)
# Get top and bottom edges
in_height, _ = torch.max(masks, dim=-1)
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
in_height_coords = in_height_coords + h * (~in_height)
top_edges, _ = torch.min(in_height_coords, dim=-1)
# Get left and right edges
in_width, _ = torch.max(masks, dim=-2)
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
right_edges, _ = torch.max(in_width_coords, dim=-1)
in_width_coords = in_width_coords + w * (~in_width)
left_edges, _ = torch.min(in_width_coords, dim=-1)
# If the mask is empty the right edge will be to the left of the left edge.
# Replace these boxes with [0, 0, 0, 0]
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
out = out * (~empty_filter).unsqueeze(-1)
# Return to original shape
if len(shape) > 2:
out = out.reshape(*shape[:-2], 4)
else:
out = out[0]
return out
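# Illustrative example: for a single 4x4 boolean mask whose foreground occupies
# rows 1-2 and columns 1-3, batched_mask_to_box returns the XYXY box
# [1, 1, 3, 2]; an all-False mask of the same shape yields [0, 0, 0, 0].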
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import warnings
from threading import Thread
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
def get_sdpa_settings():
if torch.cuda.is_available():
old_gpu = torch.cuda.get_device_properties(0).major < 7
# only use Flash Attention on Ampere (8.0) or newer GPUs
use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
if not use_flash_attn:
warnings.warn(
"Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
category=UserWarning,
stacklevel=2,
)
# keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
# available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
if pytorch_version < (2, 2):
warnings.warn(
f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
"Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
category=UserWarning,
stacklevel=2,
)
math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
else:
old_gpu = True
use_flash_attn = False
math_kernel_on = True
return old_gpu, use_flash_attn, math_kernel_on
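# These flags are meant to be consumed by attention layers elsewhere in the
# codebase; a hedged sketch of how they would typically be wired into PyTorch's
# scaled-dot-product attention backend selection (not a call made in this file,
# and the exact flag combination may differ per module):
#
#   OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
#   with torch.backends.cuda.sdp_kernel(
#       enable_flash=USE_FLASH_ATTN,
#       enable_math=MATH_KERNEL_ON,
#       enable_mem_efficient=OLD_GPU,
#   ):
#       out = torch.nn.functional.scaled_dot_product_attention(q, k, v)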
def get_connected_components(mask):
"""
Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
Inputs:
- mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
background.
Outputs:
- labels: A tensor of shape (N, 1, H, W) containing the connected component labels
for foreground pixels and 0 for background pixels.
- counts: A tensor of shape (N, 1, H, W) containing the area of the connected
components for foreground pixels and 0 for background pixels.
"""
from torch.utils.cpp_extension import load
get_connected_componnets = load(
name="get_connected_componnets",
sources=["third_parts/sam2/csrc/connected_components.cu"],
verbose=True,
extra_cuda_cflags=[
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
)
return get_connected_componnets.get_connected_componnets(mask.to(torch.uint8).contiguous())
def mask_to_box(masks: torch.Tensor):
"""
compute bounding box given an input mask
Inputs:
- masks: [B, 1, H, W] binary mask tensor (boolean)
Returns:
- box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
"""
B, _, h, w = masks.shape
device = masks.device
xs = torch.arange(w, device=device, dtype=torch.int32)
ys = torch.arange(h, device=device, dtype=torch.int32)
grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)
return bbox_coords
def _load_img_as_tensor(img_path, image_size):
img_pil = Image.open(img_path)
img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
img_np = img_np / 255.0
else:
raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
img = torch.from_numpy(img_np).permute(2, 0, 1)
video_width, video_height = img_pil.size # the original video size
return img, video_height, video_width
class AsyncVideoFrameLoader:
"""
A list of video frames to be loaded asynchronously without blocking session start.
"""
def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
self.img_paths = img_paths
self.image_size = image_size
self.offload_video_to_cpu = offload_video_to_cpu
self.img_mean = img_mean
self.img_std = img_std
# items in `self._images` will be loaded asynchronously
self.images = [None] * len(img_paths)
# catch and raise any exceptions in the async loading thread
self.exception = None
# video_height and video_width will be filled when loading the first image
self.video_height = None
self.video_width = None
# load the first frame to fill video_height and video_width and also
# to cache it (since it's most likely where the user will click)
self.__getitem__(0)
# load the rest of frames asynchronously without blocking the session start
def _load_frames():
try:
for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
self.__getitem__(n)
except Exception as e:
self.exception = e
self.thread = Thread(target=_load_frames, daemon=True)
self.thread.start()
def __getitem__(self, index):
if self.exception is not None:
raise RuntimeError("Failure in frame loading thread") from self.exception
img = self.images[index]
if img is not None:
return img
img, video_height, video_width = _load_img_as_tensor(
self.img_paths[index], self.image_size
)
self.video_height = video_height
self.video_width = video_width
# normalize by mean and std
img -= self.img_mean
img /= self.img_std
if not self.offload_video_to_cpu:
img = img.cuda(non_blocking=True)
self.images[index] = img
return img
def __len__(self):
return len(self.images)
def load_video_frames(
video_path,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
):
"""
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
The frames are resized to image_size x image_size and are loaded to GPU if
`offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
You can load the frames asynchronously by setting `async_loading_frames` to `True`.
"""
if isinstance(video_path, str) and os.path.isdir(video_path):
jpg_folder = video_path
else:
raise NotImplementedError("Only JPEG frames are supported at this moment")
frame_names = [
p
for p in os.listdir(jpg_folder)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
num_frames = len(frame_names)
if num_frames == 0:
raise RuntimeError(f"no images found in {jpg_folder}")
img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
if async_loading_frames:
lazy_images = AsyncVideoFrameLoader(
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
)
return lazy_images, lazy_images.video_height, lazy_images.video_width
images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
images = images.cuda()
img_mean = img_mean.cuda()
img_std = img_std.cuda()
# normalize by mean and std
images -= img_mean
images /= img_std
return images, video_height, video_width
def fill_holes_in_mask_scores(mask, max_area):
"""
A post processor to fill small holes in mask scores with area under `max_area`.
"""
# Holes are those connected components in background with area <= `max_area`
# (background regions are those with mask scores <= 0)
assert max_area > 0, "max_area must be positive"
labels, areas = get_connected_components(mask <= 0)
is_hole = (labels > 0) & (areas <= max_area)
# We fill holes with a small positive mask score (0.1) to change them to foreground.
mask = torch.where(is_hole, 0.1, mask)
return mask
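# Illustrative behavior (hypothetical numbers): with max_area=8, a 2x2 pocket of
# negative scores fully enclosed by positive scores is reset to 0.1 (becoming
# foreground after thresholding), while a 100-pixel background region is left
# untouched.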
def concat_points(old_point_inputs, new_points, new_labels):
"""Add new points and labels to previous point inputs (add at the end)."""
if old_point_inputs is None:
points, labels = new_points, new_labels
else:
points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
return {"point_coords": points, "point_labels": labels}
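# Illustrative example: clicks are accumulated along dim=1, with a leading batch
# dimension of 1 (shapes below are hypothetical):
#
#   first = concat_points(None, torch.zeros(1, 1, 2), torch.ones(1, 1, dtype=torch.int32))
#   both = concat_points(first, torch.zeros(1, 2, 2), torch.zeros(1, 2, dtype=torch.int32))
#   # both["point_coords"].shape == (1, 3, 2); both["point_labels"].shape == (1, 3)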
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Normalize, Resize, ToTensor
class SAM2Transforms(nn.Module):
def __init__(
self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
):
"""
Transforms for SAM2.
"""
super().__init__()
self.resolution = resolution
self.mask_threshold = mask_threshold
self.max_hole_area = max_hole_area
self.max_sprinkle_area = max_sprinkle_area
self.mean = [0.485, 0.456, 0.406]
self.std = [0.229, 0.224, 0.225]
self.to_tensor = ToTensor()
self.transforms = torch.jit.script(
nn.Sequential(
Resize((self.resolution, self.resolution)),
Normalize(self.mean, self.std),
)
)
def __call__(self, x):
x = self.to_tensor(x)
return self.transforms(x)
def forward_batch(self, img_list):
img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
img_batch = torch.stack(img_batch, dim=0)
return img_batch
def transform_coords(
self, coords: torch.Tensor, normalize=False, orig_hw=None
) -> torch.Tensor:
"""
        Expects a torch tensor with length 2 in the last dimension. The coordinates can be
        either absolute image coordinates or normalized coordinates. If the coords are in
        absolute image coordinates, set `normalize` to True and provide the original image size.
        Returns
            Coordinates scaled to the model input resolution (i.e. in [0, self.resolution]),
            which is what the SAM2 prompt encoder expects.
"""
if normalize:
assert orig_hw is not None
h, w = orig_hw
coords = coords.clone()
coords[..., 0] = coords[..., 0] / w
coords[..., 1] = coords[..., 1] / h
coords = coords * self.resolution # unnormalize coords
return coords
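    # Usage sketch (the instance name and values are illustrative): map a click given in
    # original-image pixels into the model input space before passing it to the prompt encoder:
    #   pts = transforms.transform_coords(
    #       torch.tensor([[[320., 240.]]]), normalize=True, orig_hw=(480, 640))
    #   # -> pts now lies in [0, resolution]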
def transform_boxes(
self, boxes: torch.Tensor, normalize=False, orig_hw=None
) -> torch.Tensor:
"""
        Expects a tensor of shape Bx4. The coordinates can be either absolute image coordinates
        or normalized coordinates. If the coords are in absolute image coordinates, set
        `normalize` to True and provide the original image size.
"""
boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
return boxes
def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
"""
Perform PostProcessing on output masks.
"""
from third_parts.sam2.utils.misc import get_connected_components
        masks = masks.float()
        mask_flat = masks.flatten(0, 1).unsqueeze(1)  # flatten as a 1-channel image
        if self.max_hole_area > 0:
            # Holes are those connected components in background with area <= self.max_hole_area
            # (background regions are those with mask scores <= self.mask_threshold)
            labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
is_hole = (labels > 0) & (areas <= self.max_hole_area)
is_hole = is_hole.reshape_as(masks)
            # Fill holes with a score of (self.mask_threshold + 10.0) so they become foreground.
            masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
if self.max_sprinkle_area > 0:
            labels, areas = get_connected_components(mask_flat > self.mask_threshold)
            is_sprinkle = (labels > 0) & (areas <= self.max_sprinkle_area)
            is_sprinkle = is_sprinkle.reshape_as(masks)
            # Suppress small foreground sprinkles with a score of (self.mask_threshold - 10.0) so they become background.
            masks = torch.where(is_sprinkle, self.mask_threshold - 10.0, masks)
masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
return masks
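# Minimal usage sketch of SAM2Transforms (resolution / threshold values are illustrative;
# `image` is assumed to be an HxWx3 uint8 array and `low_res_masks` a BxCxhxw logits tensor):
#   transforms = SAM2Transforms(resolution=1024, mask_threshold=0.0, max_hole_area=8.0)
#   batch = transforms.forward_batch([image])
#   masks = transforms.postprocess_masks(low_res_masks, orig_hw=(480, 640))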
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import OrderedDict
import cv2
from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
CAP_PROP_POS_FRAMES)
from mmengine.utils import (check_file_exist, mkdir_or_exist, track_progress)
class Cache:
def __init__(self, capacity):
self._cache = OrderedDict()
self._capacity = int(capacity)
if capacity <= 0:
raise ValueError('capacity must be a positive integer')
@property
def capacity(self):
return self._capacity
@property
def size(self):
return len(self._cache)
def put(self, key, val):
if key in self._cache:
return
if len(self._cache) >= self.capacity:
self._cache.popitem(last=False)
self._cache[key] = val
def get(self, key, default=None):
val = self._cache[key] if key in self._cache else default
return val
class VideoReader:
"""Video class with similar usage to a list object.
This video wrapper class provides convenient apis to access frames.
    OpenCV's VideoCapture class has an issue that jumping to a certain frame may be
    inaccurate; this class fixes it by checking the position after each jump.
Cache is used when decoding videos. So if the same frame is visited for
the second time, there is no need to decode again if it is stored in the
cache.
Examples:
>>> import mmcv
>>> v = mmcv.VideoReader('sample.mp4')
>>> len(v) # get the total frame number with `len()`
120
>>> for img in v: # v is iterable
>>> mmcv.imshow(img)
>>> v[5] # get the 6th frame
"""
def __init__(self, filename, cache_capacity=10):
# Check whether the video path is a url
if not filename.startswith(('https://', 'http://')):
check_file_exist(filename, 'Video file not found: ' + filename)
self._vcap = cv2.VideoCapture(filename)
assert cache_capacity > 0
self._cache = Cache(cache_capacity)
self._position = 0
# get basic info
self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
self._fps = self._vcap.get(CAP_PROP_FPS)
self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
@property
def vcap(self):
""":obj:`cv2.VideoCapture`: The raw VideoCapture object."""
return self._vcap
@property
def opened(self):
"""bool: Indicate whether the video is opened."""
return self._vcap.isOpened()
@property
def width(self):
"""int: Width of video frames."""
return self._width
@property
def height(self):
"""int: Height of video frames."""
return self._height
@property
def resolution(self):
"""tuple: Video resolution (width, height)."""
return (self._width, self._height)
@property
def fps(self):
"""float: FPS of the video."""
return self._fps
@property
def frame_cnt(self):
"""int: Total frames of the video."""
return self._frame_cnt
@property
def fourcc(self):
"""str: "Four character code" of the video."""
return self._fourcc
@property
def position(self):
"""int: Current cursor position, indicating frame decoded."""
return self._position
def _get_real_position(self):
return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
def _set_real_position(self, frame_id):
self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
pos = self._get_real_position()
for _ in range(frame_id - pos):
self._vcap.read()
self._position = frame_id
def read(self):
"""Read the next frame.
        If the next frame has been decoded before and is in the cache, return it
        directly; otherwise decode it, cache it and return it.
Returns:
ndarray or None: Return the frame if successful, otherwise None.
"""
# pos = self._position
if self._cache:
img = self._cache.get(self._position)
if img is not None:
ret = True
else:
if self._position != self._get_real_position():
self._set_real_position(self._position)
ret, img = self._vcap.read()
if ret:
self._cache.put(self._position, img)
else:
ret, img = self._vcap.read()
if ret:
self._position += 1
return img
def get_frame(self, frame_id):
"""Get frame by index.
Args:
frame_id (int): Index of the expected frame, 0-based.
Returns:
ndarray or None: Return the frame if successful, otherwise None.
"""
if frame_id < 0 or frame_id >= self._frame_cnt:
raise IndexError(
f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
if frame_id == self._position:
return self.read()
if self._cache:
img = self._cache.get(frame_id)
if img is not None:
self._position = frame_id + 1
return img
self._set_real_position(frame_id)
ret, img = self._vcap.read()
if ret:
if self._cache:
self._cache.put(self._position, img)
self._position += 1
return img
def current_frame(self):
"""Get the current frame (frame that is just visited).
Returns:
ndarray or None: If the video is fresh, return None, otherwise
return the frame.
"""
if self._position == 0:
return None
return self._cache.get(self._position - 1)
def cvt2frames(self,
frame_dir,
file_start=0,
filename_tmpl='{:06d}.jpg',
start=0,
max_num=0,
show_progress=True):
"""Convert a video to frame images.
Args:
frame_dir (str): Output directory to store all the frame images.
file_start (int): Filenames will start from the specified number.
filename_tmpl (str): Filename template with the index as the
placeholder.
start (int): The starting frame index.
max_num (int): Maximum number of frames to be written.
show_progress (bool): Whether to show a progress bar.
"""
mkdir_or_exist(frame_dir)
if max_num == 0:
task_num = self.frame_cnt - start
else:
task_num = min(self.frame_cnt - start, max_num)
if task_num <= 0:
raise ValueError('start must be less than total frame number')
if start > 0:
self._set_real_position(start)
def write_frame(file_idx):
img = self.read()
if img is None:
return
filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
cv2.imwrite(filename, img)
if show_progress:
track_progress(write_frame, range(file_start,
file_start + task_num))
else:
for i in range(task_num):
write_frame(file_start + i)
def __len__(self):
return self.frame_cnt
def __getitem__(self, index):
if isinstance(index, slice):
return [
self.get_frame(i)
for i in range(*index.indices(self.frame_cnt))
]
# support negative indexing
if index < 0:
index += self.frame_cnt
if index < 0:
raise IndexError('index out of range')
return self.get_frame(index)
def __iter__(self):
self._set_real_position(0)
return self
def __next__(self):
img = self.read()
if img is not None:
return img
else:
raise StopIteration
next = __next__
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self._vcap.release()
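# Usage sketch beyond the docstring examples above ('sample.mp4' and the output directory
# are illustrative): dump a video to JPEG frames and release the capture automatically:
#   with VideoReader('sample.mp4') as v:
#       v.cvt2frames('frames_out', filename_tmpl='{:06d}.jpg', show_progress=False)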
#!/usr/bin/env bash
set -x
FILE=$1
CONFIG=$2
GPUS=$3
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-$((28500 + $RANDOM % 2000))}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
DEEPSPEED=${DEEPSPEED:-deepspeed_zero2}
if command -v torchrun &> /dev/null
then
echo "Using torchrun mode."
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \
torchrun --nnodes=${NNODES} \
--node_rank=${NODE_RANK} \
--master_addr=${MASTER_ADDR} \
--master_port=${PORT} \
--nproc_per_node=${GPUS} \
tools/${FILE}.py ${CONFIG} --launcher pytorch --deepspeed $DEEPSPEED "${@:4}"
else
echo "Using launch mode."
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \
python -m torch.distributed.launch \
--nnodes=${NNODES} \
--node_rank=${NODE_RANK} \
--master_addr=${MASTER_ADDR} \
--master_port=${PORT} \
--nproc_per_node=${GPUS} \
tools/${FILE}.py ${CONFIG} --launcher pytorch --deepspeed $DEEPSPEED "${@:4}"
fi
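# Example invocation (the script path, config name and extra args are illustrative):
#   NNODES=1 PORT=29500 bash tools/dist.sh train my_config 8 --cfg-options key=value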
#!/usr/bin/env bash
set -x
FILE=$1
CONFIG=$2
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
MASTER_PORT=${MASTER_PORT:-$((28500 + $RANDOM % 2000))}
PARTITION=${PARTITION:-DUMMY}
JOB_NAME=${JOB_NAME:-DUMMY}
QUOTATYPE=${QUOTATYPE:-auto}
SRUN_ARGS=${SRUN_ARGS:-""}
DEEPSPEED=${DEEPSPEED:-deepspeed_zero2}
PY_ARGS=${@:3}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \
CUDA_HOME=${CONDA_PREFIX} \
LD_LIBRARY_PATH=${CONDA_PREFIX}/lib:$(realpath ~/.local/lib) \
MASTER_PORT=$MASTER_PORT \
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
--quotatype=${QUOTATYPE} \
${SRUN_ARGS} \
python -u tools/${FILE}.py ${CONFIG} --launcher="slurm" --deepspeed $DEEPSPEED ${PY_ARGS}
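# Example submission (partition, job name and config are illustrative):
#   PARTITION=my_partition JOB_NAME=eval GPUS=8 GPUS_PER_NODE=8 bash tools/slurm.sh test my_config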
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os
import os.path as osp
from types import FunctionType
from mmengine import print_log
from mmengine.config import Config, DictAction
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
from xtuner.configs import cfgs_name_path
from xtuner.model.utils import guess_load_checkpoint
from xtuner.registry import MAP_FUNC
from mmengine.model import is_model_wrapper
def parse_args():
parser = argparse.ArgumentParser(description='Test model')
parser.add_argument('config', help='config file name or path.')
parser.add_argument('--checkpoint', default=None, help='checkpoint file')
parser.add_argument(
'--work-dir',
help='the directory to save the file containing evaluation metrics')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--deepspeed',
default=None,
help='Dummy option'
)
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def register_function(cfg_dict):
if isinstance(cfg_dict, dict):
for key, value in dict.items(cfg_dict):
if isinstance(value, FunctionType):
value_str = str(value)
if value_str not in MAP_FUNC:
MAP_FUNC.register_module(module=value, name=value_str)
cfg_dict[key] = value_str
else:
register_function(value)
elif isinstance(cfg_dict, (list, tuple)):
for value in cfg_dict:
register_function(value)
def main():
args = parse_args()
if args.deepspeed is not None:
print_log("Deepspeed is not adopted during inference, Skipped.", level=logging.WARN)
# parse config
if not osp.isfile(args.config):
try:
args.config = cfgs_name_path[args.config]
except KeyError:
raise FileNotFoundError(f'Cannot find {args.config}')
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# register FunctionType object in cfg to `MAP_FUNC` Registry and
# change these FunctionType object to str
register_function(cfg._cfg_dict)
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
# build the runner from config
if 'runner_type' not in cfg:
# build the default runner
runner = Runner.from_cfg(cfg)
else:
# build customized runner from the registry
# if 'runner_type' is set in the cfg
runner = RUNNERS.build(cfg)
if args.checkpoint is not None:
state_dict = guess_load_checkpoint(args.checkpoint)
if is_model_wrapper(runner.model):
runner.model.module.load_state_dict(state_dict, strict=False)
else:
runner.model.load_state_dict(state_dict, strict=False)
runner.logger.info(f'Load checkpoint from {args.checkpoint}')
else:
Warning("No checkpoint !!!")
# start testing
runner.test()
if __name__ == '__main__':
main()
from xtuner.tools.train import main as train
try:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
except Exception:
pass
if __name__ == '__main__':
train()
from mmengine.hooks import Hook
from xtuner.registry import BUILDER
class SpecialDatasetInfoHook(Hook):
def __init__(self, tokenizer, is_intern_repo_dataset=False, special_tokens=None):
self.tokenizer = BUILDER.build(tokenizer)
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.is_intern_repo_dataset = is_intern_repo_dataset
def log(self, runner, dataset, mode='train'):
def _log(input_ids, log_prefix=''):
if self.is_intern_repo_dataset:
input_ids = [abs(x) for x in input_ids]
text = self.tokenizer.decode(input_ids)
            runner.logger.info(log_prefix + text)
runner.logger.info(f'Num {mode} samples {len(dataset)}')
runner.logger.info(f'{mode} example:')
if 'chosen_ids' in dataset[0]:
_log(dataset[0]['chosen_ids'], log_prefix='chosen: ')
_log(dataset[0]['rejected_ids'], log_prefix='rejected: ')
else:
_log(dataset[0]['input_ids'])
def before_train(self, runner) -> None:
do_train = runner.train_loop is not None
do_eval = runner.val_loop is not None
if do_train:
train_dataset = runner.train_dataloader.dataset
self.log(runner, train_dataset, mode='train')
if do_eval:
eval_dataset = runner.val_dataloader.dataset
self.log(runner, eval_dataset, mode='eval')
def before_val(self, runner) -> None:
eval_dataset = runner.val_dataloader.dataset
self.log(runner, eval_dataset, mode='eval')
def before_test(self, runner) -> None:
test_dataset = runner.test_dataloader.dataset
self.log(runner, test_dataset, mode='test')
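# Config sketch (the tokenizer spec and special tokens below are illustrative): register the
# hook so the tokenizer learns the extra tokens and one decoded sample is logged per split:
#   custom_hooks = [
#       dict(type=SpecialDatasetInfoHook,
#            tokenizer=dict(type=AutoTokenizer.from_pretrained,
#                           pretrained_model_name_or_path='path/to/model'),
#            special_tokens=['<SEG>'])
#   ]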
from .loops import TestLoop
from .video_loops import VideoTestLoop
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from mmengine.runner import ValLoop as MMENGINE_ValLoop
from mmengine.dist import broadcast_object_list, is_main_process, get_world_size, get_rank, barrier, collect_results
import math
import torch
from mmengine.model import is_model_wrapper
from types import MethodType
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
PROMPT_TEMPLATE)
from xtuner.tools.utils import get_stop_criteria, is_cn_string
from transformers import GenerationConfig
TORCH_DTYPE_MAP = dict(
fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
class TestLoop(MMENGINE_ValLoop):
def __init__(self, runner, dataloader, evaluator=None, torch_dtype='fp16', select_metric='first') -> None:
        # the wrapped dataset must be a ConcatDataset (its sub-datasets are iterated below)
super(MMENGINE_ValLoop, self).__init__(runner, dataloader)
self._runner = runner
self.torch_dtype = torch_dtype
if torch_dtype is not None:
self.torch_dtype = TORCH_DTYPE_MAP[torch_dtype]
self.select_metric = select_metric
def run(self) -> dict:
"""Launch Test."""
self.runner.logger.info('==================== Start test loop ===================')
self.runner.call_hook('before_test')
self.runner.call_hook('before_test_epoch')
if is_model_wrapper(self.runner.model):
model = self.runner.model.module
else:
model = self.runner.model
model.gradient_checkpointing_disable()
model.eval()
model.cuda()
rank = get_rank()
metrics = []
# Ensure that eta and log are displayed correctly.
current_run_total_ids = 0
for _, dataset in enumerate(self.dataloader.dataset.datasets):
if not hasattr(model, 'preparing_for_generation'):
model.preparing_for_generation = MethodType(default_preparing_for_generation, model)
print("Warning, the model do not have the preparing_for_generation() function, using the default!!!")
model.preparing_for_generation(dataset.metainfo)
# split per rank
results = []
n_samples = len(dataset)
per_rank_samples = math.ceil(n_samples / get_world_size())
per_rank_ids = range(per_rank_samples * rank,
min(n_samples, per_rank_samples * (rank + 1)))
for idx in per_rank_ids:
data_batch = dataset[idx]
self.run_iter(current_run_total_ids, data_batch, results, model)
current_run_total_ids += 1
barrier()
self.runner.logger.info('==================== Start collect results ===================')
results = collect_results(results, len(dataset))
            self.runner.logger.info('========= Start evaluating the current dataset ===========')
if is_main_process():
metric = dataset.evaluate(results, self.runner.work_dir)
objects = [metric]
else:
objects = [None]
broadcast_object_list(objects)
metric = objects[0]
metrics.append(metric)
# select metrics
if self.select_metric == 'first':
metrics = metrics[0]
else:
raise NotImplementedError
self.runner.logger.info('================ Ending test loop ================')
self.runner.call_hook('after_test_epoch', metrics=metrics)
self.runner.call_hook('after_test')
return metrics
@torch.no_grad()
def run_iter(self, idx, data_batch, results, model):
assert 'text_prompts' in data_batch and 'pixel_values' in data_batch and 'img_id' in data_batch
prediction = {'img_id': data_batch['img_id']}
self.runner.call_hook(
'before_test_iter', batch_idx=idx, data_batch=data_batch)
outputs = model.predict_forward(**data_batch)
prediction.update(outputs)
results.append(prediction)
self.runner.call_hook(
'after_test_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)
def default_preparing_for_generation(self, metainfo):
# set stop criteria and generation configs for model
assert hasattr(self, 'tokenizer'), "The Model does not have the tokenizer!!!"
self.bot_name = 'BOT'
template = PROMPT_TEMPLATE['internlm2_chat']
self.template = template
stop_words = []
stop_words += template.get('STOP_WORDS', [])
stop_criteria = get_stop_criteria(
tokenizer=self.tokenizer, stop_words=stop_words)
self.stop_criteria = stop_criteria
default_generation_kwargs = dict(
max_new_tokens=2048,
do_sample=False,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=(
self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id is not None
else self.tokenizer.eos_token_id
),
)
default_generation_kwargs.update(metainfo.get('generation_kwargs', {}))
self.gen_config = GenerationConfig(**default_generation_kwargs)
return
class AnnoLoop(MMENGINE_ValLoop):
def __init__(self, runner, dataloader, evaluator=None, torch_dtype='fp16', select_metric='first') -> None:
        # the wrapped dataset must be a ConcatDataset (its sub-datasets are iterated below)
super(MMENGINE_ValLoop, self).__init__(runner, dataloader)
self._runner = runner
self.torch_dtype = torch_dtype
if torch_dtype is not None:
self.torch_dtype = TORCH_DTYPE_MAP[torch_dtype]
self.select_metric = select_metric
def run(self) -> dict:
"""Launch Test."""
self.runner.logger.info('==================== Start test loop ===================')
self.runner.call_hook('before_test')
self.runner.call_hook('before_test_epoch')
if is_model_wrapper(self.runner.model):
model = self.runner.model.module
else:
model = self.runner.model
model.eval()
rank = get_rank()
metrics = []
# Ensure that eta and log are displayed correctly.
current_run_total_ids = 0
for _, dataset in enumerate(self.dataloader.dataset.datasets):
# split per rank
results = []
n_samples = len(dataset)
per_rank_samples = math.ceil(n_samples / get_world_size())
per_rank_ids = range(per_rank_samples * rank,
min(n_samples, per_rank_samples * (rank + 1)))
for idx in per_rank_ids:
data_batch = dataset[idx]
self.run_iter(current_run_total_ids, data_batch, results, model)
current_run_total_ids += 1
if hasattr(model, 'save_step'):
model.save_step(last=True)
barrier()
self.runner.logger.info('==================== Start collect results ===================')
results = collect_results(results, len(dataset))
            self.runner.logger.info('========= Start evaluating the current dataset ===========')
if is_main_process():
metric = dataset.evaluate(results, self.runner.work_dir)
objects = [metric]
else:
objects = [None]
broadcast_object_list(objects)
metric = objects[0]
metrics.append(metric)
# select metrics
if self.select_metric == 'first':
metrics = metrics[0]
else:
raise NotImplementedError
self.runner.logger.info('================ Ending test loop ================')
self.runner.call_hook('after_test_epoch', metrics=metrics)
self.runner.call_hook('after_test')
return metrics
@torch.no_grad()
def run_iter(self, idx, data_batch, results, model):
prediction = {}
self.runner.call_hook(
'before_test_iter', batch_idx=idx, data_batch=data_batch)
outputs = model.predict_forward(**data_batch)
prediction.update(outputs)
results.append(prediction)
self.runner.call_hook(
'after_test_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path
import cv2
import mmengine
from mmengine.runner import ValLoop as MMENGINE_ValLoop
from mmengine.dist import broadcast_object_list, is_main_process, get_world_size, get_rank, barrier, collect_results
import math
import torch
from mmengine.model import is_model_wrapper
from types import MethodType
from xtuner.utils import PROMPT_TEMPLATE
from xtuner.tools.utils import get_stop_criteria
from transformers import GenerationConfig
from pycocotools import mask as _mask
from mmengine.visualization.visualizer import Visualizer
from vlm.utils import VideoReader
TORCH_DTYPE_MAP = dict(fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
VID_INTERVAL = 4
def visualize(data_batch, prediction, visualize_path='work_dirs/visualize'):
if 'video_path' in data_batch:
vid_frames = VideoReader(data_batch['video_path'])[::VID_INTERVAL]
vid_id = os.path.basename(data_batch['video_path']).split('.')[0]
text_prompts = data_batch['text_prompts']
mmengine.mkdir_or_exist(os.path.join(visualize_path, vid_id))
visualizer = Visualizer()
mmengine.mkdir_or_exist(os.path.join(visualize_path, vid_id, "vid"))
for id_frame, img in enumerate(vid_frames):
out_path = os.path.join(visualize_path, vid_id, "vid", "{:06d}.jpg".format(id_frame))
cv2.imwrite(out_path, img)
for id_text, text in enumerate(text_prompts):
mmengine.mkdir_or_exist(os.path.join(visualize_path, vid_id, "sample_{:06d}".format(id_text)))
mmengine.put_text(text, os.path.join(visualize_path, vid_id, "sample_{:06d}".format(id_text), 'text.txt'))
for id_frame, img in enumerate(vid_frames):
visualizer.set_image(img)
mask = prediction['prediction_masks'][id_text][id_frame]
mask = _mask.decode(mask).astype(bool)
visualizer.draw_binary_masks(mask, colors='g')
visual_result = visualizer.get_image()
out_path = os.path.join(visualize_path, vid_id, "sample_{:06d}".format(id_text),
"{:06d}.jpg".format(id_frame))
cv2.imwrite(out_path, visual_result)
else:
images_files = data_batch['images']
vid_id = data_batch['video_id']
text_prompts = data_batch['text_prompts']
image_folder = data_batch['image_folder']
mmengine.mkdir_or_exist(os.path.join(visualize_path, "{:06d}".format(vid_id)))
visualizer = Visualizer()
mmengine.mkdir_or_exist(os.path.join(visualize_path, "{:06d}".format(vid_id), "vid"))
for id_frame, img_file in enumerate(images_files):
img = cv2.imread(os.path.join(image_folder, img_file))
out_path = os.path.join(visualize_path, "{:06d}".format(vid_id), "vid", os.path.basename(img_file))
cv2.imwrite(out_path, img)
for id_text, text in enumerate(text_prompts):
mmengine.mkdir_or_exist(os.path.join(visualize_path, "{:06d}".format(vid_id), "sample_{:06d}".format(id_text)))
mmengine.put_text(text, os.path.join(visualize_path, "{:06d}".format(vid_id), "sample_{:06d}".format(id_text),
'text.txt'))
for id_frame, img_file in enumerate(images_files):
img = cv2.imread(os.path.join(image_folder, img_file))
visualizer.set_image(img)
mask = prediction['prediction_masks'][id_text][id_frame]
mask = _mask.decode(mask).astype(bool)
visualizer.draw_binary_masks(mask, colors='g')
visual_result = visualizer.get_image()
out_path = os.path.join(visualize_path, "{:06d}".format(vid_id), "sample_{:06d}".format(id_text),
os.path.basename(img_file))
cv2.imwrite(out_path, visual_result)
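# Resulting layout under `visualize_path` (directory names follow the code above):
#   <vid_id>/vid/000000.jpg ...               the raw (sub-sampled) frames
#   <vid_id>/sample_000000/text.txt           the text prompt for that sample
#   <vid_id>/sample_000000/000000.jpg ...     frames overlaid with the predicted masks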
class VideoTestLoop(MMENGINE_ValLoop):
def __init__(self, runner, dataloader, torch_dtype='fp16', select_metric='first', visualize=None, evaluator=None) -> None:
        # the wrapped dataset must be a ConcatDataset (its sub-datasets are iterated below)
super(MMENGINE_ValLoop, self).__init__(runner, dataloader)
self._runner = runner
self.torch_dtype = torch_dtype
if torch_dtype is not None:
self.torch_dtype = TORCH_DTYPE_MAP[torch_dtype]
self.select_metric = select_metric
self.visualize = visualize
self.evaluator = evaluator
def run(self) -> dict:
"""Launch Test."""
self.runner.logger.info('==================== Start test loop ===================')
self.runner.call_hook('before_test')
self.runner.call_hook('before_test_epoch')
if is_model_wrapper(self.runner.model):
model = self.runner.model.module
else:
model = self.runner.model
model.gradient_checkpointing_disable()
model.eval()
model.cuda()
rank = get_rank()
metrics = []
# Ensure that eta and log are displayed correctly.
current_run_total_ids = 0
for _, dataset in enumerate(self.dataloader.dataset.datasets):
if not hasattr(model, 'preparing_for_generation'):
model.preparing_for_generation = MethodType(default_preparing_for_generation, model)
print("Warning, the model do not have the preparing_for_generation() function, using the default!!!")
model.preparing_for_generation(dataset.metainfo)
# split per rank
results = []
n_samples = len(dataset)
per_rank_samples = math.ceil(n_samples / get_world_size())
running_tot = per_rank_samples * get_world_size()
assert running_tot >= n_samples
per_rank_ids = range(per_rank_samples * rank, per_rank_samples * (rank + 1))
for idx in per_rank_ids:
if n_samples <= idx:
data_batch = dataset[n_samples - 1]
else:
data_batch = dataset[idx]
self.run_iter(current_run_total_ids, data_batch, results, model)
current_run_total_ids += 1
barrier()
self.runner.logger.info('==================== Start collect results ===================')
results = collect_results(results, n_samples)
            self.runner.logger.info('========= Start evaluating the current dataset ===========')
if is_main_process():
metric = dataset.evaluate(results, self.runner.work_dir)
objects = [metric]
else:
objects = [None]
broadcast_object_list(objects)
metric = objects[0]
metrics.append(metric)
# select metrics
if self.select_metric == 'first':
metrics = metrics[0]
else:
raise NotImplementedError
self.runner.logger.info('================ Ending test loop ================')
self.runner.call_hook('after_test_epoch', metrics=metrics)
self.runner.call_hook('after_test')
return metrics
@torch.no_grad()
def run_iter(self, idx, data_batch, results, model):
prediction = {'video_id': data_batch['video_id']}
self.runner.call_hook(
'before_test_iter', batch_idx=idx, data_batch=data_batch)
outputs = model.predict_forward(**data_batch)
prediction.update(outputs)
results.append(prediction)
if self.visualize:
# if not prediction['is_exists'][0].all():
# print(prediction['is_exists'])
visualize(data_batch=data_batch, prediction=prediction, visualize_path=self.visualize)
self.runner.call_hook(
'after_test_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)
def default_preparing_for_generation(self, metainfo):
# set stop criteria and generation configs for model
assert hasattr(self, 'tokenizer'), "The Model does not have the tokenizer!!!"
self.bot_name = 'BOT'
template = PROMPT_TEMPLATE['internlm2_chat']
self.template = template
stop_words = []
stop_words += template.get('STOP_WORDS', [])
stop_criteria = get_stop_criteria(
tokenizer=self.tokenizer, stop_words=stop_words)
self.stop_criteria = stop_criteria
default_generation_kwargs = dict(
max_new_tokens=2048,
do_sample=False,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=(
self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id is not None
else self.tokenizer.eos_token_id
),
)
default_generation_kwargs.update(metainfo.get('generation_kwargs', {}))
self.gen_config = GenerationConfig(**default_generation_kwargs)
return
from .load_checkpoint import load_checkpoint_with_prefix, load_state_dict_to_model
from .video_io import VideoReader
import logging
from mmengine.runner.checkpoint import CheckpointLoader
from mmengine.logging.logger import print_log
from huggingface_hub import hf_hub_download
HF_HUB_PREFIX = 'hf-hub:'
def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
"""Load partial pretrained model with specific prefix.
Args:
prefix (str): The prefix of sub-module.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str | None): Same as :func:`torch.load`.
Defaults to None.
logger: logger
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
    if filename.startswith(HF_HUB_PREFIX):
model_id = filename[len(HF_HUB_PREFIX):]
filename = hf_hub_download(model_id, 'pytorch_model.bin')
checkpoint = CheckpointLoader.load_checkpoint(filename, map_location=map_location, logger=logger)
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
if not prefix:
return state_dict
if not prefix.endswith('.'):
prefix += '.'
prefix_len = len(prefix)
state_dict = {
k[prefix_len:]: v
for k, v in state_dict.items() if k.startswith(prefix)
}
assert state_dict, f'{prefix} is not in the pretrained model'
return state_dict
def load_state_dict_to_model(model, state_dict, logger='current'):
    # strict=False so that key mismatches are reported through the logger below instead of
    # torch's default error message; any mismatch still raises a RuntimeError.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys:
print_log(missing_keys, logger=logger, level=logging.ERROR)
raise RuntimeError()
if unexpected_keys:
print_log(unexpected_keys, logger=logger, level=logging.ERROR)
raise RuntimeError()
print_log("Loaded checkpoint successfully", logger=logger)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import is_torch_available, logging
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
def _compute_default_rope_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
)
if len(rope_kwargs) > 0:
base = rope_kwargs["base"]
dim = rope_kwargs["dim"]
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
attention_factor = 1.0 # Unused in this type of RoPE
# Compute the inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, attention_factor
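# Quick sketch of the keyword-argument path (base/dim values are illustrative):
#   inv_freq, scale = _compute_default_rope_parameters(base=10000.0, dim=128)
#   # inv_freq: shape (dim // 2,) == (64,); scale == 1.0 for the default RoPE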
def _compute_linear_scaling_rope_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
)
if len(rope_kwargs) > 0:
factor = rope_kwargs["factor"]
elif config is not None:
factor = config.rope_scaling["factor"]
# Gets the default RoPE parameters
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
# Then applies linear scaling to the frequencies.
# NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
# applying scaling to the inverse frequencies is equivalent.
inv_freq /= factor
return inv_freq, attention_factor
def _compute_dynamic_ntk_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length, used to update the dynamic RoPE at inference time.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
)
if len(rope_kwargs) > 0:
base = rope_kwargs["base"]
dim = rope_kwargs["dim"]
max_position_embeddings = rope_kwargs["max_position_embeddings"]
factor = rope_kwargs["factor"]
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
max_position_embeddings = config.max_position_embeddings
factor = config.rope_scaling["factor"]
attention_factor = 1.0 # Unused in this type of RoPE
# seq_len: default to max_position_embeddings, e.g. at init time
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
# Compute the inverse frequencies
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, attention_factor
def _compute_yarn_parameters(
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with NTK scaling. Please refer to the
[original paper](https://arxiv.org/abs/2309.00071)
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# No need to keep BC with yarn, unreleased when this new pattern was created.
if len(rope_kwargs) > 0:
raise ValueError(
f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
)
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
max_position_embeddings = config.max_position_embeddings
factor = config.rope_scaling["factor"]
# Sets the attention factor as suggested in the paper
attention_factor = config.rope_scaling.get("attention_factor")
if attention_factor is None:
attention_factor = 0.1 * math.log(factor) + 1.0
# Optional config options
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
beta_fast = config.rope_scaling.get("beta_fast") or 32
beta_slow = config.rope_scaling.get("beta_slow") or 1
# Compute the inverse frequencies
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
"""Inverse dimension formula to find the dimension based on the number of rotations"""
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
"""Find dimension range bounds based on rotations"""
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1)
def linear_ramp_factor(min, max, dim):
if min == max:
max += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
# to expand the possible context length. In other words, interpolation = apply scaling factor.
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
)
return inv_freq, attention_factor
def _compute_longrope_parameters(
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with LongRoPE scaling. Please refer to the
[original implementation](https://github.com/microsoft/LongRoPE)
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
# No need to keep BC with longrope, unreleased when this new pattern was created.
if len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
f"{rope_kwargs}"
)
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
long_factor = config.rope_scaling["long_factor"]
short_factor = config.rope_scaling["short_factor"]
factor = config.rope_scaling.get("factor")
attention_factor = config.rope_scaling.get("attention_factor")
# NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
# values to compute the default attention scaling factor, instead of using `factor`.
if hasattr(config, "original_max_position_embeddings"):
if seq_len and seq_len < config.original_max_position_embeddings:
expanded_max_position_embeddings = config.original_max_position_embeddings
else:
expanded_max_position_embeddings = config.max_position_embeddings
max_position_embeddings = config.original_max_position_embeddings
factor = expanded_max_position_embeddings / max_position_embeddings
else:
max_position_embeddings = config.max_position_embeddings
expanded_max_position_embeddings = max_position_embeddings * factor
# Sets the attention factor as suggested in the paper
if attention_factor is None:
if factor <= 1.0:
attention_factor = 1.0
else:
attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
# Compute the inverse frequencies -- scaled based on the target sequence length
if expanded_max_position_embeddings > max_position_embeddings:
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
else:
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
return inv_freq, attention_factor
def _compute_llama3_parameters(
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies for llama 3.1.
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# Gets the default RoPE parameters
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
factor = config.rope_scaling["factor"] # `8` in the original implementation
low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * math.pi / inv_freq
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
return inv_freq_llama, attention_factor
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
# parameterizations, as long as the callable has the same signature.
ROPE_INIT_FUNCTIONS = {
"default": _compute_default_rope_parameters,
"linear": _compute_linear_scaling_rope_parameters,
"dynamic": _compute_dynamic_ntk_parameters,
"yarn": _compute_yarn_parameters,
"longrope": _compute_longrope_parameters,
"llama3": _compute_llama3_parameters,
}
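# Sketch of registering a custom parameterization, as the comment above describes (the name
# "my_rope" and the callable are illustrative); the callable should mirror the signatures
# above and return (inv_freq, attention_factor):
#   def _compute_my_rope_parameters(config=None, device=None, seq_len=None, **rope_kwargs):
#       inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
#       return inv_freq, attention_factor  # apply any custom rescaling here
#   ROPE_INIT_FUNCTIONS["my_rope"] = _compute_my_rope_parameters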
def _check_received_keys(
rope_type: str,
received_keys: set,
required_keys: set,
optional_keys: Optional[set] = None,
ignore_keys: Optional[set] = None,
):
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
if "type" in received_keys:
received_keys -= {"type"}
required_keys.add("rope_type")
# Some models need to store model-specific keys, and we don't want to throw warning at them
if ignore_keys is not None:
received_keys -= ignore_keys
missing_keys = required_keys - received_keys
if missing_keys:
raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
if optional_keys is not None:
unused_keys = received_keys - required_keys - optional_keys
else:
unused_keys = received_keys - required_keys
if unused_keys:
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"original_max_position_embeddings"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
logger.warning(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
beta_fast = rope_scaling.get("beta_fast")
if beta_fast is not None and not isinstance(beta_fast, float):
logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
beta_slow = rope_scaling.get("beta_slow")
if beta_slow is not None and not isinstance(beta_slow, float):
logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
if (beta_fast or 32) < (beta_slow or 1):
logger.warning(
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
)
def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "short_factor", "long_factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
short_factor = rope_scaling.get("short_factor")
    if not (isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor)):
        logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
if not len(short_factor) == dim // 2:
logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
long_factor = rope_scaling.get("long_factor")
    if not (isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor)):
        logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
if not len(long_factor) == dim // 2:
logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
# `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
# unique to longrope (= undesirable)
if hasattr(config, "original_max_position_embeddings"):
logger.warning_once(
"This model has set a `original_max_position_embeddings` field, to be used together with "
"`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
"with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
"as it is compatible with most model architectures."
)
else:
factor = rope_scaling.get("factor")
if factor is None:
logger.warning("Missing required keys in `rope_scaling`: 'factor'")
elif not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None:
if not isinstance(attention_factor, float) or attention_factor < 0.0:
logger.warning(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
if low_freq_factor is None or not isinstance(low_freq_factor, float):
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
if high_freq_factor is None or not isinstance(high_freq_factor, float):
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
if high_freq_factor <= low_freq_factor:
logger.warning(
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
)
original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
logger.warning(
"`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
f"{original_max_position_embeddings}"
)
if original_max_position_embeddings >= config.max_position_embeddings:
logger.warning(
"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
)
# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
ROPE_VALIDATION_FUNCTIONS = {
"default": _validate_default_rope_parameters,
"linear": _validate_linear_scaling_rope_parameters,
"dynamic": _validate_dynamic_scaling_rope_parameters,
"yarn": _validate_yarn_parameters,
"longrope": _validate_longrope_parameters,
"llama3": _validate_llama3_parameters,
}
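# Illustrative only (this helper is never called): a sketch of how the mapping above can
# be extended for a custom RoPE type, mirroring the dynamic-update note. The "my_custom"
# type name and its single check are assumptions made for demonstration.
def _example_register_custom_rope_validator():
    def _validate_my_custom_parameters(config, ignore_keys=None):
        factor = config.rope_scaling.get("factor")
        if factor is None or not isinstance(factor, float) or factor < 1.0:
            logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
    ROPE_VALIDATION_FUNCTIONS["my_custom"] = _validate_my_custom_parameters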
def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
"""
Validate the RoPE config arguments, given a `PretrainedConfig` object
"""
rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
if rope_scaling is None:
return
# BC: "rope_type" was originally "type"
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
if validation_fn is not None:
validation_fn(config, ignore_keys=ignore_keys)
else:
logger.warning(
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
)
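# Illustrative usage sketch. `LlamaConfig` and the values below are examples only; any
# `PretrainedConfig` subclass carrying a `rope_scaling` dict is validated the same way.
# >>> from transformers import LlamaConfig
# >>> config = LlamaConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
# >>> rope_config_validation(config)  # logs warnings for malformed entries, returns nothing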
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import OrderedDict
import cv2
from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
CAP_PROP_POS_FRAMES)
from mmengine.utils import (check_file_exist, mkdir_or_exist, track_progress)
class Cache:
def __init__(self, capacity):
        if capacity <= 0:
            raise ValueError('capacity must be a positive integer')
        self._cache = OrderedDict()
        self._capacity = int(capacity)
@property
def capacity(self):
return self._capacity
@property
def size(self):
return len(self._cache)
def put(self, key, val):
if key in self._cache:
return
if len(self._cache) >= self.capacity:
self._cache.popitem(last=False)
self._cache[key] = val
def get(self, key, default=None):
val = self._cache[key] if key in self._cache else default
return val
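# Illustrative usage of `Cache` (arbitrary values): insertion order is preserved and the
# oldest entry is evicted once `capacity` is reached; `put` silently ignores keys that
# already exist, and `get` falls back to `default` for missing keys.
# >>> cache = Cache(2)
# >>> cache.put('a', 1)
# >>> cache.put('b', 2)
# >>> cache.put('c', 3)       # capacity reached, so 'a' (the oldest entry) is evicted
# >>> cache.get('a', 'gone')
# 'gone'
# >>> cache.get('c')
# 3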
class VideoReader:
"""Video class with similar usage to a list object.
This video wrapper class provides convenient apis to access frames.
There exists an issue of OpenCV's VideoCapture class that jumping to a
certain frame may be inaccurate. It is fixed in this class by checking
the position after jumping each time.
Cache is used when decoding videos. So if the same frame is visited for
the second time, there is no need to decode again if it is stored in the
cache.
Examples:
>>> import mmcv
>>> v = mmcv.VideoReader('sample.mp4')
>>> len(v) # get the total frame number with `len()`
120
        >>> for img in v:  # v is iterable
        >>>     mmcv.imshow(img)
>>> v[5] # get the 6th frame
"""
def __init__(self, filename, cache_capacity=10):
# Check whether the video path is a url
if not filename.startswith(('https://', 'http://')):
check_file_exist(filename, 'Video file not found: ' + filename)
self._vcap = cv2.VideoCapture(filename)
assert cache_capacity > 0
self._cache = Cache(cache_capacity)
self._position = 0
# get basic info
self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
self._fps = self._vcap.get(CAP_PROP_FPS)
self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
@property
def vcap(self):
""":obj:`cv2.VideoCapture`: The raw VideoCapture object."""
return self._vcap
@property
def opened(self):
"""bool: Indicate whether the video is opened."""
return self._vcap.isOpened()
@property
def width(self):
"""int: Width of video frames."""
return self._width
@property
def height(self):
"""int: Height of video frames."""
return self._height
@property
def resolution(self):
"""tuple: Video resolution (width, height)."""
return (self._width, self._height)
@property
def fps(self):
"""float: FPS of the video."""
return self._fps
@property
def frame_cnt(self):
"""int: Total frames of the video."""
return self._frame_cnt
@property
def fourcc(self):
"""str: "Four character code" of the video."""
return self._fourcc
@property
def position(self):
"""int: Current cursor position, indicating frame decoded."""
return self._position
def _get_real_position(self):
return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
def _set_real_position(self, frame_id):
self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
pos = self._get_real_position()
for _ in range(frame_id - pos):
self._vcap.read()
self._position = frame_id
def read(self):
"""Read the next frame.
        If the next frame has been decoded before and is still in the cache,
        it is returned directly; otherwise it is decoded, cached and returned.
Returns:
ndarray or None: Return the frame if successful, otherwise None.
"""
if self._cache:
img = self._cache.get(self._position)
if img is not None:
ret = True
else:
if self._position != self._get_real_position():
self._set_real_position(self._position)
ret, img = self._vcap.read()
if ret:
self._cache.put(self._position, img)
else:
ret, img = self._vcap.read()
if ret:
self._position += 1
return img
def get_frame(self, frame_id):
"""Get frame by index.
Args:
frame_id (int): Index of the expected frame, 0-based.
Returns:
ndarray or None: Return the frame if successful, otherwise None.
"""
if frame_id < 0 or frame_id >= self._frame_cnt:
raise IndexError(
f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
if frame_id == self._position:
return self.read()
if self._cache:
img = self._cache.get(frame_id)
if img is not None:
self._position = frame_id + 1
return img
self._set_real_position(frame_id)
ret, img = self._vcap.read()
if ret:
if self._cache:
self._cache.put(self._position, img)
self._position += 1
return img
def current_frame(self):
"""Get the current frame (frame that is just visited).
Returns:
ndarray or None: If the video is fresh, return None, otherwise
return the frame.
"""
if self._position == 0:
return None
return self._cache.get(self._position - 1)
def cvt2frames(self,
frame_dir,
file_start=0,
filename_tmpl='{:06d}.jpg',
start=0,
max_num=0,
show_progress=True):
"""Convert a video to frame images.
Args:
frame_dir (str): Output directory to store all the frame images.
file_start (int): Filenames will start from the specified number.
filename_tmpl (str): Filename template with the index as the
placeholder.
start (int): The starting frame index.
max_num (int): Maximum number of frames to be written.
show_progress (bool): Whether to show a progress bar.
"""
mkdir_or_exist(frame_dir)
if max_num == 0:
task_num = self.frame_cnt - start
else:
task_num = min(self.frame_cnt - start, max_num)
if task_num <= 0:
raise ValueError('start must be less than total frame number')
if start > 0:
self._set_real_position(start)
def write_frame(file_idx):
img = self.read()
if img is None:
return
filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
cv2.imwrite(filename, img)
if show_progress:
track_progress(write_frame, range(file_start,
file_start + task_num))
else:
for i in range(task_num):
write_frame(file_start + i)
def __len__(self):
return self.frame_cnt
def __getitem__(self, index):
if isinstance(index, slice):
return [
self.get_frame(i)
for i in range(*index.indices(self.frame_cnt))
]
# support negative indexing
if index < 0:
index += self.frame_cnt
if index < 0:
raise IndexError('index out of range')
return self.get_frame(index)
def __iter__(self):
self._set_real_position(0)
return self
def __next__(self):
img = self.read()
if img is not None:
return img
else:
raise StopIteration
next = __next__
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self._vcap.release()
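# Illustrative usage of `VideoReader` (the video path and output directory are
# placeholders): the class supports `with`, indexing, slicing and frame dumping.
# >>> with VideoReader('demo.mp4') as video:
# ...     print(video.resolution, video.fps, len(video))
# ...     first, last = video[0], video[-1]     # random access; negative indices allowed
# ...     clip = video[10:20]                   # slicing returns a list of frames
# ...     video.cvt2frames('demo_frames')       # writes 000000.jpg, 000001.jpg, ...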