BSD 3-Clause License
Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Copyright 2023 Rex Cheng
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# Segment Anything Video (SA-V) Dataset
## Overview
[Segment Anything Video (SA-V)](https://ai.meta.com/datasets/segment-anything-video/) consists of 51K diverse videos and 643K high-quality spatio-temporal segmentation masks (i.e., masklets). The dataset is released under the CC BY 4.0 license. Browse the dataset [here](https://sam2.metademolab.com/dataset).
![SA-V dataset](../assets/sa_v_dataset.jpg?raw=true)
## Getting Started
### Download the dataset
Visit [here](https://ai.meta.com/datasets/segment-anything-video-downloads/) to download SA-V, including the training, val, and test sets.
### Dataset Stats
| | Num Videos | Num Masklets |
| ---------- | ---------- | ----------------------------------------- |
| SA-V train | 50,583 | 642,036 (auto 451,720 and manual 190,316) |
| SA-V val | 155 | 293 |
| SA-V test | 150 | 278 |
### Notebooks
To load and visualize the SA-V training set annotations, refer to the example [sav_visualization_example.ipynb](./sav_visualization_example.ipynb) notebook.
### SA-V train
For the SA-V training set, we release the mp4 videos and store the masklet annotations per video as JSON files. Automatic masklets and manual masklets are stored separately in two JSON files: `{video_id}_auto.json` and `{video_id}_manual.json`. They can be loaded as dictionaries in Python in the format below.
```
{
"video_id" : str; video id
"video_duration" : float64; the duration in seconds of this video
"video_frame_count" : float64; the number of frames in the video
"video_height" : float64; the height of the video
"video_width" : float64; the width of the video
"video_resolution" : float64; video_height $\times$ video_width
"video_environment" : List[str]; "Indoor" or "Outdoor"
"video_split" : str; "train" for training set
"masklet" : List[List[Dict]]; masklet annotations in list of list of RLEs.
The outer list is over frames in the video and the inner list
is over objects in the video.
"masklet_id" : List[int]; the masklet ids
"masklet_size_rel" : List[float]; the average mask area normalized by resolution
across all the frames where the object is visible
"masklet_size_abs" : List[float]; the average mask area (in pixels)
across all the frames where the object is visible
"masklet_size_bucket" : List[str]; "small": $1$ <= masklet_size_abs < $32^2$,
"medium": $32^2$ <= masklet_size_abs < $96^2$,
and "large": masklet_size_abs > $96^2$
"masklet_visibility_changes" : List[int]; the number of times where the visibility changes
after the first appearance (e.g., invisible -> visible
or visible -> invisible)
"masklet_first_appeared_frame" : List[int]; the index of the frame where the object appears
the first time in the video. Always 0 for auto masklets.
"masklet_frame_count" : List[int]; the number of frames being annotated. Note that
videos are annotated at 6 fps (annotated every 4 frames)
while the videos are at 24 fps.
"masklet_edited_frame_count" : List[int]; the number of frames being edited by human annotators.
Always 0 for auto masklets.
"masklet_type" : List[str]; "auto" or "manual"
"masklet_stability_score" : Optional[List[List[float]]]; per-mask stability scores. Auto annotation only.
"masklet_num" : int; the number of manual/auto masklets in the video
}
```
Note that SA-V train contains 50,583 videos in total, all of which have manual annotations. Among these 50,583 videos, 48,436 also have automatic annotations.
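As a minimal sketch of loading these annotations outside the notebook (the directory and video id below are placeholders), the per-frame RLEs can be decoded into binary masks with `pycocotools`:

```python
import json

import pycocotools.mask as mask_util

# Placeholder paths: point these at your local SA-V train download.
sav_train_dir = "path/to/sav_train"
video_id = "sav_000001"  # placeholder video id

# Manual and automatic annotations are stored in separate JSON files.
with open(f"{sav_train_dir}/{video_id}_manual.json") as f:
    manual_annot = json.load(f)

# "masklet" is a list over annotated (6 fps) frames; each entry is a list of
# per-object RLEs. Decode the masks of the first annotated frame.
frame0_rles = manual_annot["masklet"][0]
frame0_masks = [mask_util.decode(rle) > 0 for rle in frame0_rles]
print(manual_annot["video_id"], f"{len(frame0_masks)} objects in frame 0")
```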
### SA-V val and test
For SA-V val and test sets, we release the extracted frames as jpeg files, and the masks as png files with the following directory structure:
```
sav_val(sav_test)
├── sav_val.txt (sav_test.txt): a list of video ids in the split
├── JPEGImages_24fps # videos are extracted at 24 fps
│ ├── {video_id}
│ │ ├── 00000.jpg # video frame
│ │ ├── 00001.jpg # video frame
│ │ ├── 00002.jpg # video frame
│ │ ├── 00003.jpg # video frame
│ │ └── ...
│ ├── {video_id}
│ ├── {video_id}
│ └── ...
└── Annotations_6fps # videos are annotated at 6 fps
├── {video_id}
│ ├── 000 # obj 000
│ │ ├── 00000.png # mask for object 000 in 00000.jpg
│ │ ├── 00004.png # mask for object 000 in 00004.jpg
│ │ ├── 00008.png # mask for object 000 in 00008.jpg
│ │ ├── 00012.png # mask for object 000 in 00012.jpg
│ │ └── ...
│ ├── 001 # obj 001
│ ├── 002 # obj 002
│ └── ...
├── {video_id}
├── {video_id}
└── ...
```
All masklets in the val and test sets are manually annotated in every frame by annotators. For each annotated object in a video, we store its masks in separate PNG files (one object per PNG) rather than packing all objects of a frame into a single PNG. This is because annotated objects may overlap; for example, our SA-V dataset can contain a mask for a whole person as well as a separate mask for their hands.
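For illustration, a minimal sketch (with placeholder paths and video id) of collecting the per-object masks of one annotated frame from SA-V val:

```python
import os

import numpy as np
from PIL import Image

# Placeholder paths: adjust to your local SA-V val (or test) download.
annot_root = "sav_val/Annotations_6fps"
video_id = "sav_000001"  # placeholder video id
frame_name = "00000.png"

# Each object id (000, 001, ...) has its own subdirectory of binary masks.
per_obj_mask = {}
for obj_dir in sorted(os.listdir(os.path.join(annot_root, video_id))):
    mask_path = os.path.join(annot_root, video_id, obj_dir, frame_name)
    if os.path.exists(mask_path):  # the object may not be annotated in this frame
        per_obj_mask[obj_dir] = np.array(Image.open(mask_path)) > 0

# Masks from different objects may overlap (e.g. a person and their hands).
print({obj: int(m.sum()) for obj, m in per_obj_mask.items()})
```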
## SA-V Val and Test Evaluation
We provide an evaluator to compute the common J and F metrics on SA-V val and test sets. To run the evaluation, we need to first install a few dependencies as follows:
```
pip install -r requirements.txt
```
Then we can evaluate the predictions as follows:
```
python sav_evaluator.py --gt_root {GT_ROOT} --pred_root {PRED_ROOT}
```
or run
```
python sav_evaluator.py --help
```
to print a complete help message.
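For example, assuming the predictions for SA-V val were saved as per-object PNG files (e.g. with the VOS inference tool under `tools/`), the call would look like the following, where both paths are placeholders:
```
python sav_evaluator.py \
    --gt_root /path-to-sav-val/Annotations_6fps \
    --pred_root ./outputs/sav_val_pred_pngs
```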
The evaluator expects the `GT_ROOT` to be one of the following folder structures, and `GT_ROOT` and `PRED_ROOT` to have the same structure.
- Same as SA-V val and test directory structure
```
{GT_ROOT} # gt root folder
├── {video_id}
│ ├── 000 # all masks associated with obj 000
│ │ ├── 00000.png # mask for object 000 in frame 00000 (binary mask)
│ │ └── ...
│ ├── 001 # all masks associated with obj 001
│ ├── 002 # all masks associated with obj 002
│ └── ...
├── {video_id}
├── {video_id}
└── ...
```
In the paper for the experiments on SA-V val and test, we run inference on the 24 fps videos, and evaluate on the subset of frames where we have ground truth annotations (first and last annotated frames dropped). The evaluator will ignore the masks in frames where we don't have ground truth annotations.
- Same as [DAVIS](https://github.com/davisvideochallenge/davis2017-evaluation) directory structure
```
{GT_ROOT} # gt root folder
├── {video_id}
│ ├── 00000.png # annotations in frame 00000 (may contain multiple objects)
│ └── ...
├── {video_id}
├── {video_id}
└── ...
```
## License
The evaluation code is licensed under the [BSD 3-Clause license](./LICENSE). Please refer to the paper for more details on the models. The videos and annotations in the SA-V dataset are released under CC BY 4.0.
Third-party code: the evaluation software is heavily adapted from [`VOS-Benchmark`](https://github.com/hkchengrex/vos-benchmark) and [`DAVIS`](https://github.com/davisvideochallenge/davis2017-evaluation) (with their licenses in [`LICENSE_DAVIS`](./LICENSE_DAVIS) and [`LICENSE_VOS_BENCHMARK`](./LICENSE_VOS_BENCHMARK)).
pycocoevalcap
scikit-image
opencv-python
tqdm
pillow
numpy
matplotlib
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the sav_dataset directory of this source tree.
# adapted from https://github.com/hkchengrex/vos-benchmark
# and https://github.com/davisvideochallenge/davis2017-evaluation
# with their licenses found in the LICENSE_VOS_BENCHMARK and LICENSE_DAVIS files
# in the sav_dataset directory.
from argparse import ArgumentParser
from utils.sav_benchmark import benchmark
"""
The structure of the {GT_ROOT} can be either of the follow two structures.
{GT_ROOT} and {PRED_ROOT} should be of the same format
1. SA-V val/test structure
{GT_ROOT} # gt root folder
├── {video_id}
│ ├── 000 # all masks associated with obj 000
│ │ ├── {frame_id}.png # mask for object 000 in {frame_id} (binary mask)
│ │ └── ...
│ ├── 001 # all masks associated with obj 001
│ ├── 002 # all masks associated with obj 002
│ └── ...
├── {video_id}
├── {video_id}
└── ...
2. Similar to DAVIS structure:
{GT_ROOT} # gt root folder
├── {video_id}
│ ├── {frame_id}.png # annotation in {frame_id} (may contain multiple objects)
│ └── ...
├── {video_id}
├── {video_id}
└── ...
"""
parser = ArgumentParser()
parser.add_argument(
"--gt_root",
required=True,
help="Path to the GT folder. For SA-V, it's sav_val/Annotations_6fps or sav_test/Annotations_6fps",
)
parser.add_argument(
"--pred_root",
required=True,
help="Path to a folder containing folders of masks to be evaluated, with exactly the same structure as gt_root",
)
parser.add_argument(
"-n", "--num_processes", default=16, type=int, help="Number of concurrent processes"
)
parser.add_argument(
"-s",
"--strict",
help="Make sure every video in the gt_root folder has a corresponding video in the prediction",
action="store_true",
)
parser.add_argument(
"-q",
"--quiet",
help="Quietly run evaluation without printing the information out",
action="store_true",
)
# https://github.com/davisvideochallenge/davis2017-evaluation/blob/d34fdef71ce3cb24c1a167d860b707e575b3034c/davis2017/evaluation.py#L85
parser.add_argument(
"--do_not_skip_first_and_last_frame",
help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. "
"Set this to true for evaluation settings that don't skip the first and last frames",
action="store_true",
)
if __name__ == "__main__":
args = parser.parse_args()
benchmark(
[args.gt_root],
[args.pred_root],
args.strict,
args.num_processes,
verbose=not args.quiet,
skip_first_and_last=not args.do_not_skip_first_and_last_frame,
)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the sav_dataset directory of this source tree.
# adapted from https://github.com/hkchengrex/vos-benchmark
# and https://github.com/davisvideochallenge/davis2017-evaluation
# with their licenses found in the LICENSE_VOS_BENCHMARK and LICENSE_DAVIS files
# in the sav_dataset directory.
import math
import os
import time
from collections import defaultdict
from multiprocessing import Pool
from os import path
from typing import Dict, List, Tuple
import cv2
import numpy as np
import tqdm
from PIL import Image
from skimage.morphology import disk
class VideoEvaluator:
def __init__(self, gt_root, pred_root, skip_first_and_last=True) -> None:
"""
gt_root: path to the folder storing the gt masks
pred_root: path to the folder storing the predicted masks
skip_first_and_last: whether we should skip the evaluation of the first and the last frame.
True for SA-V val and test, same as in DAVIS semi-supervised evaluation.
"""
self.gt_root = gt_root
self.pred_root = pred_root
self.skip_first_and_last = skip_first_and_last
def __call__(self, vid_name: str) -> Tuple[str, Dict[str, float], Dict[str, float]]:
"""
vid_name: name of the video to evaluate
"""
# scan the folder to find subfolders for evaluation and
# check if the folder structure is SA-V
to_evaluate, is_sav_format = self.scan_vid_folder(vid_name)
# evaluate each (gt_path, pred_path) pair
eval_results = []
for all_frames, obj_id, gt_path, pred_path in to_evaluate:
if self.skip_first_and_last:
# skip the first and the last frames
all_frames = all_frames[1:-1]
evaluator = Evaluator(name=vid_name, obj_id=obj_id)
for frame in all_frames:
gt_array, pred_array = self.get_gt_and_pred(
gt_path, pred_path, frame, is_sav_format
)
evaluator.feed_frame(mask=pred_array, gt=gt_array)
iou, boundary_f = evaluator.conclude()
eval_results.append((obj_id, iou, boundary_f))
if is_sav_format:
iou_output, boundary_f_output = self.consolidate(eval_results)
else:
assert len(eval_results) == 1
iou_output = eval_results[0][1]
boundary_f_output = eval_results[0][2]
return vid_name, iou_output, boundary_f_output
def get_gt_and_pred(
self,
gt_path: str,
pred_path: str,
f_name: str,
is_sav_format: bool,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Get the ground-truth and predicted masks for a single frame.
"""
gt_mask_path = path.join(gt_path, f_name)
pred_mask_path = path.join(pred_path, f_name)
assert os.path.exists(pred_mask_path), f"{pred_mask_path} not found"
gt_array = np.array(Image.open(gt_mask_path))
pred_array = np.array(Image.open(pred_mask_path))
assert (
gt_array.shape[-2:] == pred_array.shape[-2:]
), f"shape mismatch: {gt_mask_path}, {pred_mask_path}"
if is_sav_format:
assert len(np.unique(gt_array)) <= 2, (
f"found more than 1 object in {gt_mask_path} "
"SA-V format assumes one object mask per png file."
)
assert len(np.unique(pred_array)) <= 2, (
f"found more than 1 object in {pred_mask_path} "
"SA-V format assumes one object mask per png file."
)
gt_array = gt_array > 0
pred_array = pred_array > 0
return gt_array, pred_array
def scan_vid_folder(self, vid_name) -> Tuple[List, bool]:
"""
Scan the folder structure of the video and return a list of folders to evaluate.
"""
vid_gt_path = path.join(self.gt_root, vid_name)
vid_pred_path = path.join(self.pred_root, vid_name)
all_files_and_dirs = sorted(os.listdir(vid_gt_path))
to_evaluate = []
if all(name.endswith(".png") for name in all_files_and_dirs):
# All files are png files, dataset structure similar to DAVIS
is_sav_format = False
frames = all_files_and_dirs
obj_dir = None
to_evaluate.append((frames, obj_dir, vid_gt_path, vid_pred_path))
else:
# SA-V dataset structure, going one layer down into each subdirectory
is_sav_format = True
for obj_dir in all_files_and_dirs:
obj_gt_path = path.join(vid_gt_path, obj_dir)
obj_pred_path = path.join(vid_pred_path, obj_dir)
frames = sorted(os.listdir(obj_gt_path))
to_evaluate.append((frames, obj_dir, obj_gt_path, obj_pred_path))
return to_evaluate, is_sav_format
def consolidate(
self, eval_results
) -> Tuple[str, Dict[str, float], Dict[str, float]]:
"""
Consolidate the results of all the objects from the video into one dictionary.
"""
iou_output = {}
boundary_f_output = {}
for obj_id, iou, boundary_f in eval_results:
assert len(iou) == 1
key = list(iou.keys())[0]
iou_output[obj_id] = iou[key]
boundary_f_output[obj_id] = boundary_f[key]
return iou_output, boundary_f_output
#################################################################################################################
# Functions below are from https://github.com/hkchengrex/vos-benchmark with minor modifications
# _seg2bmap from https://github.com/hkchengrex/vos-benchmark/blob/main/vos_benchmark/utils.py
# get_iou and Evaluator from https://github.com/hkchengrex/vos-benchmark/blob/main/vos_benchmark/evaluator.py
# benchmark from https://github.com/hkchengrex/vos-benchmark/blob/main/vos_benchmark/benchmark.py with slight mod
#################################################################################################################
def _seg2bmap(seg, width=None, height=None):
"""
From a segmentation, compute a binary boundary map with 1 pixel wide
boundaries. The boundary pixels are offset by 1/2 pixel towards the
origin from the actual segment boundary.
Arguments:
seg : Segments labeled from 1..k.
width : Width of desired bmap <= seg.shape[1]
height : Height of desired bmap <= seg.shape[0]
Returns:
bmap (ndarray): Binary boundary map.
David Martin <dmartin@eecs.berkeley.edu>
January 2003
"""
seg = seg.astype(bool)
seg[seg > 0] = 1
assert np.atleast_3d(seg).shape[2] == 1
width = seg.shape[1] if width is None else width
height = seg.shape[0] if height is None else height
h, w = seg.shape[:2]
ar1 = float(width) / float(height)
ar2 = float(w) / float(h)
assert not (
width > w or height > h or abs(ar1 - ar2) > 0.01
), "Cannot convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
e = np.zeros_like(seg)
s = np.zeros_like(seg)
se = np.zeros_like(seg)
e[:, :-1] = seg[:, 1:]
s[:-1, :] = seg[1:, :]
se[:-1, :-1] = seg[1:, 1:]
b = seg ^ e | seg ^ s | seg ^ se
b[-1, :] = seg[-1, :] ^ e[-1, :]
b[:, -1] = seg[:, -1] ^ s[:, -1]
b[-1, -1] = 0
if w == width and h == height:
bmap = b
else:
bmap = np.zeros((height, width))
for x in range(w):
for y in range(h):
if b[y, x]:
j = 1 + math.floor((y - 1) + height / h)
i = 1 + math.floor((x - 1) + width / h)
bmap[j, i] = 1
return bmap
def get_iou(intersection, pixel_sum):
# handle edge cases without resorting to epsilon
if intersection == pixel_sum:
# both mask and gt have zero pixels in them
assert intersection == 0
return 1
return intersection / (pixel_sum - intersection)
class Evaluator:
def __init__(self, boundary=0.008, name=None, obj_id=None):
# boundary: used in computing boundary F-score
self.boundary = boundary
self.name = name
self.obj_id = obj_id
self.objects_in_gt = set()
self.objects_in_masks = set()
self.object_iou = defaultdict(list)
self.boundary_f = defaultdict(list)
def feed_frame(self, mask: np.ndarray, gt: np.ndarray):
"""
Compute and accumulate metrics for a single frame (mask/gt pair)
"""
# get all objects in the ground-truth
gt_objects = np.unique(gt)
gt_objects = gt_objects[gt_objects != 0].tolist()
# get all objects in the predicted mask
mask_objects = np.unique(mask)
mask_objects = mask_objects[mask_objects != 0].tolist()
self.objects_in_gt.update(set(gt_objects))
self.objects_in_masks.update(set(mask_objects))
all_objects = self.objects_in_gt.union(self.objects_in_masks)
# boundary disk for boundary F-score. It is the same for all objects.
bound_pix = np.ceil(self.boundary * np.linalg.norm(mask.shape))
boundary_disk = disk(bound_pix)
for obj_idx in all_objects:
obj_mask = mask == obj_idx
obj_gt = gt == obj_idx
# object iou
self.object_iou[obj_idx].append(
get_iou((obj_mask * obj_gt).sum(), obj_mask.sum() + obj_gt.sum())
)
"""
# boundary f-score
This part is copied from davis2017-evaluation
"""
mask_boundary = _seg2bmap(obj_mask)
gt_boundary = _seg2bmap(obj_gt)
mask_dilated = cv2.dilate(mask_boundary.astype(np.uint8), boundary_disk)
gt_dilated = cv2.dilate(gt_boundary.astype(np.uint8), boundary_disk)
# Get the intersection
gt_match = gt_boundary * mask_dilated
fg_match = mask_boundary * gt_dilated
# Area of the intersection
n_fg = np.sum(mask_boundary)
n_gt = np.sum(gt_boundary)
# Compute precision and recall
if n_fg == 0 and n_gt > 0:
precision = 1
recall = 0
elif n_fg > 0 and n_gt == 0:
precision = 0
recall = 1
elif n_fg == 0 and n_gt == 0:
precision = 1
recall = 1
else:
precision = np.sum(fg_match) / float(n_fg)
recall = np.sum(gt_match) / float(n_gt)
# Compute F measure
if precision + recall == 0:
F = 0
else:
F = 2 * precision * recall / (precision + recall)
self.boundary_f[obj_idx].append(F)
def conclude(self):
all_iou = {}
all_boundary_f = {}
for object_id in self.objects_in_gt:
all_iou[object_id] = np.mean(self.object_iou[object_id]) * 100
all_boundary_f[object_id] = np.mean(self.boundary_f[object_id]) * 100
return all_iou, all_boundary_f
def benchmark(
gt_roots,
mask_roots,
strict=True,
num_processes=None,
*,
verbose=True,
skip_first_and_last=True,
):
"""
gt_roots: a list of paths to datasets, i.e., [path_to_DatasetA, path_to_DatasetB, ...]
mask_roots: same as above, but the .png are masks predicted by the model
strict: when True, all videos in the dataset must have corresponding predictions.
Setting it to False is useful in cases where the ground-truth contains both train/val
sets, but the model only predicts the val subset.
Either way, if a video is predicted (i.e., the corresponding folder exists),
then it must at least contain all the masks in the ground truth annotations.
Masks that are in the prediction but not in the ground-truth
(i.e., sparse annotations) are ignored.
skip_first_and_last: whether we should skip the first and the last frame in evaluation.
This is used by DAVIS 2017 in their semi-supervised evaluation.
It should be disabled for unsupervised evaluation.
"""
assert len(gt_roots) == len(mask_roots)
single_dataset = len(gt_roots) == 1
if verbose:
if skip_first_and_last:
print(
"We are *SKIPPING* the evaluation of the first and the last frame (standard for semi-supervised video object segmentation)."
)
else:
print(
"We are *NOT SKIPPING* the evaluation of the first and the last frame (*NOT STANDARD* for semi-supervised video object segmentation)."
)
pool = Pool(num_processes)
start = time.time()
to_wait = []
for gt_root, mask_root in zip(gt_roots, mask_roots):
# Validate folders
validated = True
gt_videos = os.listdir(gt_root)
mask_videos = os.listdir(mask_root)
# if the user passed the root directory instead of Annotations
if len(gt_videos) != len(mask_videos):
if "Annotations" in gt_videos:
if ".png" not in os.listdir(path.join(gt_root, "Annotations"))[0]:
gt_root = path.join(gt_root, "Annotations")
gt_videos = os.listdir(gt_root)
# remove non-folder items
gt_videos = list(filter(lambda x: path.isdir(path.join(gt_root, x)), gt_videos))
mask_videos = list(
filter(lambda x: path.isdir(path.join(mask_root, x)), mask_videos)
)
if not strict:
videos = sorted(list(set(gt_videos) & set(mask_videos)))
else:
gt_extras = set(gt_videos) - set(mask_videos)
mask_extras = set(mask_videos) - set(gt_videos)
if len(gt_extras) > 0:
print(
f"Videos that are in {gt_root} but not in {mask_root}: {gt_extras}"
)
validated = False
if len(mask_extras) > 0:
print(
f"Videos that are in {mask_root} but not in {gt_root}: {mask_extras}"
)
validated = False
if not validated:
print("Validation failed. Exiting.")
exit(1)
videos = sorted(gt_videos)
if verbose:
print(
f"In dataset {gt_root}, we are evaluating on {len(videos)} videos: {videos}"
)
if single_dataset:
if verbose:
results = tqdm.tqdm(
pool.imap(
VideoEvaluator(
gt_root, mask_root, skip_first_and_last=skip_first_and_last
),
videos,
),
total=len(videos),
)
else:
results = pool.map(
VideoEvaluator(
gt_root, mask_root, skip_first_and_last=skip_first_and_last
),
videos,
)
else:
to_wait.append(
pool.map_async(
VideoEvaluator(
gt_root, mask_root, skip_first_and_last=skip_first_and_last
),
videos,
)
)
pool.close()
all_global_jf, all_global_j, all_global_f = [], [], []
all_object_metrics = []
for i, mask_root in enumerate(mask_roots):
if not single_dataset:
results = to_wait[i].get()
all_iou = []
all_boundary_f = []
object_metrics = {}
for name, iou, boundary_f in results:
all_iou.extend(list(iou.values()))
all_boundary_f.extend(list(boundary_f.values()))
object_metrics[name] = (iou, boundary_f)
global_j = np.array(all_iou).mean()
global_f = np.array(all_boundary_f).mean()
global_jf = (global_j + global_f) / 2
time_taken = time.time() - start
"""
Build string for reporting results
"""
# find max length for padding
ml = max(*[len(n) for n in object_metrics.keys()], len("Global score"))
# build header
out_string = f'{"sequence":<{ml}},{"obj":>3}, {"J&F":>4}, {"J":>4}, {"F":>4}\n'
out_string += f'{"Global score":<{ml}},{"":>3}, {global_jf:.1f}, {global_j:.1f}, {global_f:.1f}\n'
# append one line for each object
for name, (iou, boundary_f) in object_metrics.items():
for object_idx in iou.keys():
j, f = iou[object_idx], boundary_f[object_idx]
jf = (j + f) / 2
out_string += (
f"{name:<{ml}},{object_idx:03}, {jf:>4.1f}, {j:>4.1f}, {f:>4.1f}\n"
)
# print to console
if verbose:
print(out_string.replace(",", " "), end="")
print("\nSummary:")
print(
f"Global score: J&F: {global_jf:.1f} J: {global_j:.1f} F: {global_f:.1f}"
)
print(f"Time taken: {time_taken:.2f}s")
# print to file
result_path = path.join(mask_root, "results.csv")
print(f"Saving the results to {result_path}")
with open(result_path, "w") as f:
f.write(out_string)
all_global_jf.append(global_jf)
all_global_j.append(global_j)
all_global_f.append(global_f)
all_object_metrics.append(object_metrics)
return all_global_jf, all_global_j, all_global_f, all_object_metrics
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the sav_dataset directory of this source tree.
import json
import os
from typing import Dict, List, Optional, Tuple
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_util
def decode_video(video_path: str) -> List[np.ndarray]:
"""
Decode the video and return the RGB frames
"""
video = cv2.VideoCapture(video_path)
video_frames = []
while video.isOpened():
ret, frame = video.read()
if ret:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
video_frames.append(frame)
else:
break
return video_frames
def show_anns(masks, colors: List, borders=True) -> None:
"""
show the annotations
"""
# return if no masks
if len(masks) == 0:
return
# sort masks by size
sorted_annot_and_color = sorted(
zip(masks, colors), key=(lambda x: x[0].sum()), reverse=True
)
H, W = sorted_annot_and_color[0][0].shape[0], sorted_annot_and_color[0][0].shape[1]
canvas = np.ones((H, W, 4))
canvas[:, :, 3] = 0 # set the alpha channel
contour_thickness = max(1, int(min(5, 0.01 * min(H, W))))
for mask, color in sorted_annot_and_color:
canvas[mask] = np.concatenate([color, [0.55]])
if borders:
contours, _ = cv2.findContours(
np.array(mask, dtype=np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
)
cv2.drawContours(
canvas, contours, -1, (0.05, 0.05, 0.05, 1), thickness=contour_thickness
)
ax = plt.gca()
ax.imshow(canvas)
class SAVDataset:
"""
SAVDataset is a class to load the SAV dataset and visualize the annotations.
"""
def __init__(self, sav_dir, annot_sample_rate=4):
"""
Args:
sav_dir: the directory of the SAV dataset
annot_sample_rate: the sampling rate of the annotations.
The annotations are aligned with the videos at 6 fps.
"""
self.sav_dir = sav_dir
self.annot_sample_rate = annot_sample_rate
self.manual_mask_colors = np.random.random((256, 3))
self.auto_mask_colors = np.random.random((256, 3))
def read_frames(self, mp4_path: str) -> None:
"""
Read the frames and downsample them to align with the annotations.
"""
if not os.path.exists(mp4_path):
print(f"{mp4_path} doesn't exist.")
return None
else:
# decode the video
frames = decode_video(mp4_path)
print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).")
# downsample the frames to align with the annotations
frames = frames[:: self.annot_sample_rate]
print(
f"Videos are annotated every {self.annot_sample_rate} frames. "
"To align with the annotations, "
f"downsample the video to {len(frames)} frames."
)
return frames
def get_frames_and_annotations(
self, video_id: str
) -> Tuple[List | None, Dict | None, Dict | None]:
"""
Get the frames and annotations for the video.
"""
# load the video
mp4_path = os.path.join(self.sav_dir, video_id + ".mp4")
frames = self.read_frames(mp4_path)
if frames is None:
return None, None, None
# load the manual annotations
manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json")
if not os.path.exists(manual_annot_path):
print(f"{manual_annot_path} doesn't exist. Something might be wrong.")
manual_annot = None
else:
manual_annot = json.load(open(manual_annot_path))
# load the auto annotations
auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json")
if not os.path.exists(auto_annot_path):
print(f"{auto_annot_path} doesn't exist.")
auto_annot = None
else:
auto_annot = json.load(open(auto_annot_path))
return frames, manual_annot, auto_annot
def visualize_annotation(
self,
frames: List[np.ndarray],
auto_annot: Optional[Dict],
manual_annot: Optional[Dict],
annotated_frame_id: int,
show_auto=True,
show_manual=True,
) -> None:
"""
Visualize the annotations on the annotated_frame_id.
If show_manual is True, show the manual annotations.
If show_auto is True, show the auto annotations.
By default, show both auto and manual annotations.
"""
if annotated_frame_id >= len(frames):
print("invalid annotated_frame_id")
return
rles = []
colors = []
if show_manual and manual_annot is not None:
rles.extend(manual_annot["masklet"][annotated_frame_id])
colors.extend(
self.manual_mask_colors[
: len(manual_annot["masklet"][annotated_frame_id])
]
)
if show_auto and auto_annot is not None:
rles.extend(auto_annot["masklet"][annotated_frame_id])
colors.extend(
self.auto_mask_colors[: len(auto_annot["masklet"][annotated_frame_id])]
)
plt.imshow(frames[annotated_frame_id])
if len(rles) > 0:
masks = [mask_util.decode(rle) > 0 for rle in rles]
show_anns(masks, colors)
else:
print("No annotation will be shown")
plt.axis("off")
plt.show()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from setuptools import find_packages, setup
# Package metadata
NAME = "SAM-2"
VERSION = "1.0"
DESCRIPTION = "SAM 2: Segment Anything in Images and Videos"
URL = "https://github.com/facebookresearch/sam2"
AUTHOR = "Meta AI"
AUTHOR_EMAIL = "segment-anything@meta.com"
LICENSE = "Apache 2.0"
# Read the contents of README file
with open("README.md", "r", encoding="utf-8") as f:
LONG_DESCRIPTION = f.read()
# Required dependencies
REQUIRED_PACKAGES = [
"torch>=2.5.1",
"torchvision>=0.20.1",
"numpy>=1.24.4",
"tqdm>=4.66.1",
"hydra-core>=1.3.2",
"iopath>=0.1.10",
"pillow>=9.4.0",
]
EXTRA_PACKAGES = {
"notebooks": [
"matplotlib>=3.9.1",
"jupyter>=1.0.0",
"opencv-python>=4.7.0",
"eva-decord>=0.6.1",
],
"interactive-demo": [
"Flask>=3.0.3",
"Flask-Cors>=5.0.0",
"av>=13.0.0",
"dataclasses-json>=0.6.7",
"eva-decord>=0.6.1",
"gunicorn>=23.0.0",
"imagesize>=1.4.1",
"pycocotools>=2.0.8",
"strawberry-graphql>=0.243.0",
],
"dev": [
"black==24.2.0",
"usort==1.0.2",
"ufmt==2.0.0b2",
"fvcore>=0.1.5.post20221221",
"pandas>=2.2.2",
"scikit-image>=0.24.0",
"tensorboard>=2.17.0",
"pycocotools>=2.0.8",
"tensordict>=0.6.0",
"opencv-python>=4.7.0",
"submitit>=1.5.1",
],
}
# By default, we also build the SAM 2 CUDA extension.
# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
# By default, we allow SAM 2 installation to proceed even with build errors.
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
# Catch and skip errors during extension building and print a warning message
# (note that this message only shows up under verbose build mode
# "pip install -v -e ." or "python setup.py build_ext -v")
CUDA_ERROR_MSG = (
"{}\n\n"
"Failed to build the SAM 2 CUDA extension due to the error above. "
"You can still use SAM 2 and it's OK to ignore the error above, although some "
"post-processing functionality may be limited (which doesn't affect the results in most cases; "
"see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n"
)
def get_extensions():
if not BUILD_CUDA:
return []
try:
from torch.utils.cpp_extension import CUDAExtension
srcs = ["sam2/csrc/connected_components.cu"]
compile_args = {
"cxx": [],
"nvcc": [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
}
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
except Exception as e:
if BUILD_ALLOW_ERRORS:
print(CUDA_ERROR_MSG.format(e))
ext_modules = []
else:
raise e
return ext_modules
try:
from torch.utils.cpp_extension import BuildExtension
class BuildExtensionIgnoreErrors(BuildExtension):
def finalize_options(self):
try:
super().finalize_options()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
def build_extensions(self):
try:
super().build_extensions()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
def get_ext_filename(self, ext_name):
try:
return super().get_ext_filename(ext_name)
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
return "_C.so"
cmdclass = {
"build_ext": (
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
if BUILD_ALLOW_ERRORS
else BuildExtension.with_options(no_python_abi_suffix=True)
)
}
except Exception as e:
cmdclass = {}
if BUILD_ALLOW_ERRORS:
print(CUDA_ERROR_MSG.format(e))
else:
raise e
# Setup configuration
setup(
name=NAME,
version=VERSION,
description=DESCRIPTION,
long_description=LONG_DESCRIPTION,
long_description_content_type="text/markdown",
url=URL,
author=AUTHOR,
author_email=AUTHOR_EMAIL,
license=LICENSE,
packages=find_packages(exclude="notebooks"),
include_package_data=True,
install_requires=REQUIRED_PACKAGES,
extras_require=EXTRA_PACKAGES,
python_requires=">=3.10.0",
ext_modules=get_extensions(),
cmdclass=cmdclass,
)
## SAM 2 toolkits
This directory provides toolkits for additional SAM 2 use cases.
### Semi-supervised VOS inference
The `vos_inference.py` script can be used to generate predictions for semi-supervised video object segmentation (VOS) evaluation on datasets such as [DAVIS](https://davischallenge.org/index.html), [MOSE](https://henghuiding.github.io/MOSE/) or the SA-V dataset.
After installing SAM 2 and its dependencies, the script can be used as follows (using the [DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). It saves the prediction PNG files to the `--output_mask_dir`.
```bash
python ./tools/vos_inference.py \
--sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \
--sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \
--base_video_dir /path-to-davis-2017/JPEGImages/480p \
--input_mask_dir /path-to-davis-2017/Annotations/480p \
--video_list_file /path-to-davis-2017/ImageSets/2017/val.txt \
--output_mask_dir ./outputs/davis_2017_pred_pngs
```
(replace `/path-to-davis-2017` with the path to DAVIS 2017 dataset)
To evaluate on the SA-V dataset, where the object masks are stored as per-object PNG files, we need to **add the `--per_obj_png_file` flag** as follows (using SA-V val as an example). With this flag, the script also saves the output masks as per-object PNG files.
```bash
python ./tools/vos_inference.py \
--sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \
--sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \
--base_video_dir /path-to-sav-val/JPEGImages_24fps \
--input_mask_dir /path-to-sav-val/Annotations_6fps \
--video_list_file /path-to-sav-val/sav_val.txt \
--per_obj_png_file \
--output_mask_dir ./outputs/sav_val_pred_pngs
```
(replace `/path-to-sav-val` with the path to SA-V val)
Then, we can use the evaluation tools or servers for each dataset to get the performance of the prediction PNG files above.
Note: by default, the `vos_inference.py` script above assumes that all objects to track already appear on frame 0 in each video (as is the case in DAVIS, MOSE or SA-V). **For VOS datasets that don't have all objects to track appearing in the first frame (such as LVOS or YouTube-VOS), please add the `--track_object_appearing_later_in_video` flag when using `vos_inference.py`**.
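For example, a hypothetical invocation on such a dataset could look like the following, where the `/path-to-dataset` directories are placeholders for the dataset's JPEG frames, input masks, and video list:
```bash
python ./tools/vos_inference.py \
    --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \
    --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \
    --base_video_dir /path-to-dataset/JPEGImages \
    --input_mask_dir /path-to-dataset/Annotations \
    --video_list_file /path-to-dataset/val.txt \
    --track_object_appearing_later_in_video \
    --output_mask_dir ./outputs/pred_pngs
```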
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
from collections import defaultdict
import numpy as np
import torch
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor
# the PNG palette for DAVIS 2017 dataset
DAVIS_PALETTE = b"\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0"
def load_ann_png(path):
"""Load a PNG file as a mask and its palette."""
mask = Image.open(path)
palette = mask.getpalette()
mask = np.array(mask).astype(np.uint8)
return mask, palette
def save_ann_png(path, mask, palette):
"""Save a mask as a PNG file with the given palette."""
assert mask.dtype == np.uint8
assert mask.ndim == 2
output_mask = Image.fromarray(mask)
output_mask.putpalette(palette)
output_mask.save(path)
def get_per_obj_mask(mask):
"""Split a mask into per-object masks."""
object_ids = np.unique(mask)
object_ids = object_ids[object_ids > 0].tolist()
per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids}
return per_obj_mask
def put_per_obj_mask(per_obj_mask, height, width):
"""Combine per-object masks into a single mask."""
mask = np.zeros((height, width), dtype=np.uint8)
object_ids = sorted(per_obj_mask)[::-1]
for object_id in object_ids:
object_mask = per_obj_mask[object_id]
object_mask = object_mask.reshape(height, width)
mask[object_mask] = object_id
return mask
def load_masks_from_dir(
input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False
):
"""Load masks from a directory as a dict of per-object masks."""
if not per_obj_png_file:
input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png")
if allow_missing and not os.path.exists(input_mask_path):
return {}, None
input_mask, input_palette = load_ann_png(input_mask_path)
per_obj_input_mask = get_per_obj_mask(input_mask)
else:
per_obj_input_mask = {}
input_palette = None
# each object is a directory in "{object_id:%03d}" format
for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
object_id = int(object_name)
input_mask_path = os.path.join(
input_mask_dir, video_name, object_name, f"{frame_name}.png"
)
if allow_missing and not os.path.exists(input_mask_path):
continue
input_mask, input_palette = load_ann_png(input_mask_path)
per_obj_input_mask[object_id] = input_mask > 0
return per_obj_input_mask, input_palette
def save_masks_to_dir(
output_mask_dir,
video_name,
frame_name,
per_obj_output_mask,
height,
width,
per_obj_png_file,
output_palette,
):
"""Save masks to a directory as PNG files."""
os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
if not per_obj_png_file:
output_mask = put_per_obj_mask(per_obj_output_mask, height, width)
output_mask_path = os.path.join(
output_mask_dir, video_name, f"{frame_name}.png"
)
save_ann_png(output_mask_path, output_mask, output_palette)
else:
for object_id, object_mask in per_obj_output_mask.items():
object_name = f"{object_id:03d}"
os.makedirs(
os.path.join(output_mask_dir, video_name, object_name),
exist_ok=True,
)
output_mask = object_mask.reshape(height, width).astype(np.uint8)
output_mask_path = os.path.join(
output_mask_dir, video_name, object_name, f"{frame_name}.png"
)
save_ann_png(output_mask_path, output_mask, output_palette)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def vos_inference(
predictor,
base_video_dir,
input_mask_dir,
output_mask_dir,
video_name,
score_thresh=0.0,
use_all_masks=False,
per_obj_png_file=False,
):
"""Run VOS inference on a single video with the given predictor."""
# load the video frames and initialize the inference state on this video
video_dir = os.path.join(base_video_dir, video_name)
frame_names = [
os.path.splitext(p)[0]
for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(
video_path=video_dir, async_loading_frames=False
)
height = inference_state["video_height"]
width = inference_state["video_width"]
input_palette = None
# fetch mask inputs from input_mask_dir (either only mask for the first frame, or all available masks)
if not use_all_masks:
# use only the first frame's ground-truth mask as the input mask
input_frame_inds = [0]
else:
# use all mask files available in the input_mask_dir as the input masks
if not per_obj_png_file:
input_frame_inds = [
idx
for idx, name in enumerate(frame_names)
if os.path.exists(
os.path.join(input_mask_dir, video_name, f"{name}.png")
)
]
else:
input_frame_inds = [
idx
for object_name in os.listdir(os.path.join(input_mask_dir, video_name))
for idx, name in enumerate(frame_names)
if os.path.exists(
os.path.join(input_mask_dir, video_name, object_name, f"{name}.png")
)
]
# check and make sure we got at least one input frame
if len(input_frame_inds) == 0:
raise RuntimeError(
f"In {video_name=}, got no input masks in {input_mask_dir=}. "
"Please make sure the input masks are available in the correct format."
)
input_frame_inds = sorted(set(input_frame_inds))
# add those input masks to SAM 2 inference state before propagation
object_ids_set = None
for input_frame_idx in input_frame_inds:
try:
per_obj_input_mask, input_palette = load_masks_from_dir(
input_mask_dir=input_mask_dir,
video_name=video_name,
frame_name=frame_names[input_frame_idx],
per_obj_png_file=per_obj_png_file,
)
except FileNotFoundError as e:
raise RuntimeError(
f"In {video_name=}, failed to load input mask for frame {input_frame_idx=}. "
"Please add the `--track_object_appearing_later_in_video` flag "
"for VOS datasets that don't have all objects to track appearing "
"in the first frame (such as LVOS or YouTube-VOS)."
) from e
# get the list of object ids to track from the first input frame
if object_ids_set is None:
object_ids_set = set(per_obj_input_mask)
for object_id, object_mask in per_obj_input_mask.items():
# check and make sure no new object ids appear only in later frames
if object_id not in object_ids_set:
raise RuntimeError(
f"In {video_name=}, got a new {object_id=} appearing only in a "
f"later {input_frame_idx=} (but not appearing in the first frame). "
"Please add the `--track_object_appearing_later_in_video` flag "
"for VOS datasets that don't have all objects to track appearing "
"in the first frame (such as LVOS or YouTube-VOS)."
)
predictor.add_new_mask(
inference_state=inference_state,
frame_idx=input_frame_idx,
obj_id=object_id,
mask=object_mask,
)
# check and make sure we have at least one object to track
if object_ids_set is None or len(object_ids_set) == 0:
raise RuntimeError(
f"In {video_name=}, got no object ids on {input_frame_inds=}. "
"Please add the `--track_object_appearing_later_in_video` flag "
"for VOS datasets that don't have all objects to track appearing "
"in the first frame (such as LVOS or YouTube-VOS)."
)
# run propagation throughout the video and collect the results in a dict
os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
output_palette = input_palette or DAVIS_PALETTE
video_segments = {} # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
inference_state
):
per_obj_output_mask = {
out_obj_id: (out_mask_logits[i] > score_thresh).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
video_segments[out_frame_idx] = per_obj_output_mask
# write the output masks as palette PNG files to output_mask_dir
for out_frame_idx, per_obj_output_mask in video_segments.items():
save_masks_to_dir(
output_mask_dir=output_mask_dir,
video_name=video_name,
frame_name=frame_names[out_frame_idx],
per_obj_output_mask=per_obj_output_mask,
height=height,
width=width,
per_obj_png_file=per_obj_png_file,
output_palette=output_palette,
)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def vos_separate_inference_per_object(
predictor,
base_video_dir,
input_mask_dir,
output_mask_dir,
video_name,
score_thresh=0.0,
use_all_masks=False,
per_obj_png_file=False,
):
"""
Run VOS inference on a single video with the given predictor.
Unlike `vos_inference`, this function runs inference separately for each object
in a video, which could be applied to datasets like LVOS or YouTube-VOS that
don't have all objects to track appearing in the first frame (i.e. some objects
might appear only later in the video).
"""
# load the video frames and initialize the inference state on this video
video_dir = os.path.join(base_video_dir, video_name)
frame_names = [
os.path.splitext(p)[0]
for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(
video_path=video_dir, async_loading_frames=False
)
height = inference_state["video_height"]
width = inference_state["video_width"]
input_palette = None
# collect all the object ids and their input masks
inputs_per_object = defaultdict(dict)
for idx, name in enumerate(frame_names):
if per_obj_png_file or os.path.exists(
os.path.join(input_mask_dir, video_name, f"{name}.png")
):
per_obj_input_mask, input_palette = load_masks_from_dir(
input_mask_dir=input_mask_dir,
video_name=video_name,
frame_name=frame_names[idx],
per_obj_png_file=per_obj_png_file,
allow_missing=True,
)
for object_id, object_mask in per_obj_input_mask.items():
# skip empty masks
if not np.any(object_mask):
continue
# if `use_all_masks=False`, we only use the first mask for each object
if len(inputs_per_object[object_id]) > 0 and not use_all_masks:
continue
print(f"adding mask from frame {idx} as input for {object_id=}")
inputs_per_object[object_id][idx] = object_mask
# run inference separately for each object in the video
object_ids = sorted(inputs_per_object)
output_scores_per_object = defaultdict(dict)
for object_id in object_ids:
# add those input masks to SAM 2 inference state before propagation
input_frame_inds = sorted(inputs_per_object[object_id])
predictor.reset_state(inference_state)
for input_frame_idx in input_frame_inds:
predictor.add_new_mask(
inference_state=inference_state,
frame_idx=input_frame_idx,
obj_id=object_id,
mask=inputs_per_object[object_id][input_frame_idx],
)
# run propagation throughout the video and collect the results in a dict
for out_frame_idx, _, out_mask_logits in predictor.propagate_in_video(
inference_state,
start_frame_idx=min(input_frame_inds),
reverse=False,
):
obj_scores = out_mask_logits.cpu().numpy()
output_scores_per_object[object_id][out_frame_idx] = obj_scores
# post-processing: consolidate the per-object scores into per-frame masks
os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
output_palette = input_palette or DAVIS_PALETTE
video_segments = {} # video_segments contains the per-frame segmentation results
for frame_idx in range(len(frame_names)):
scores = torch.full(
size=(len(object_ids), 1, height, width),
fill_value=-1024.0,
dtype=torch.float32,
)
for i, object_id in enumerate(object_ids):
if frame_idx in output_scores_per_object[object_id]:
scores[i] = torch.from_numpy(
output_scores_per_object[object_id][frame_idx]
)
if not per_obj_png_file:
scores = predictor._apply_non_overlapping_constraints(scores)
per_obj_output_mask = {
object_id: (scores[i] > score_thresh).cpu().numpy()
for i, object_id in enumerate(object_ids)
}
video_segments[frame_idx] = per_obj_output_mask
# write the output masks as palette PNG files to output_mask_dir
for frame_idx, per_obj_output_mask in video_segments.items():
save_masks_to_dir(
output_mask_dir=output_mask_dir,
video_name=video_name,
frame_name=frame_names[frame_idx],
per_obj_output_mask=per_obj_output_mask,
height=height,
width=width,
per_obj_png_file=per_obj_png_file,
output_palette=output_palette,
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--sam2_cfg",
type=str,
default="configs/sam2.1/sam2.1_hiera_b+.yaml",
help="SAM 2 model configuration file",
)
parser.add_argument(
"--sam2_checkpoint",
type=str,
default="./checkpoints/sam2.1_hiera_base_plus.pt",
help="path to the SAM 2 model checkpoint",
)
parser.add_argument(
"--base_video_dir",
type=str,
required=True,
help="directory containing videos (as JPEG files) to run VOS prediction on",
)
parser.add_argument(
"--input_mask_dir",
type=str,
required=True,
help="directory containing input masks (as PNG files) of each video",
)
parser.add_argument(
"--video_list_file",
type=str,
default=None,
help="text file containing the list of video names to run VOS prediction on",
)
parser.add_argument(
"--output_mask_dir",
type=str,
required=True,
help="directory to save the output masks (as PNG files)",
)
parser.add_argument(
"--score_thresh",
type=float,
default=0.0,
help="threshold for the output mask logits (default: 0.0)",
)
parser.add_argument(
"--use_all_masks",
action="store_true",
help="whether to use all available PNG files in input_mask_dir "
"(default without this flag: just the first PNG file as input to the SAM 2 model; "
"usually we don't need this flag, since semi-supervised VOS evaluation usually takes input from the first frame only)",
)
parser.add_argument(
"--per_obj_png_file",
action="store_true",
help="whether use separate per-object PNG files for input and output masks "
"(default without this flag: all object masks are packed into a single PNG file on each frame following DAVIS format; "
"note that the SA-V dataset stores each object mask as an individual PNG file and requires this flag)",
)
parser.add_argument(
"--apply_postprocessing",
action="store_true",
help="whether to apply postprocessing (e.g. hole-filling) to the output masks "
"(we don't apply such post-processing in the SAM 2 model evaluation)",
)
parser.add_argument(
"--track_object_appearing_later_in_video",
action="store_true",
help="whether to track objects that appear later in the video (i.e. not on the first frame; "
"some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
)
parser.add_argument(
"--use_vos_optimized_video_predictor",
action="store_true",
help="whether to use vos optimized video predictor with all modules compiled",
)
args = parser.parse_args()
# if we use per-object PNG files, they could possibly overlap in inputs and outputs
hydra_overrides_extra = [
"++model.non_overlap_masks=" + ("false" if args.per_obj_png_file else "true")
]
predictor = build_sam2_video_predictor(
config_file=args.sam2_cfg,
ckpt_path=args.sam2_checkpoint,
apply_postprocessing=args.apply_postprocessing,
hydra_overrides_extra=hydra_overrides_extra,
vos_optimized=args.use_vos_optimized_video_predictor,
)
if args.use_all_masks:
print("using all available masks in input_mask_dir as input to the SAM 2 model")
else:
print(
"using only the first frame's mask in input_mask_dir as input to the SAM 2 model"
)
# if a video list file is provided, read the video names from the file
# (otherwise, we use all subdirectories in base_video_dir)
if args.video_list_file is not None:
with open(args.video_list_file, "r") as f:
video_names = [v.strip() for v in f.readlines()]
else:
video_names = [
p
for p in os.listdir(args.base_video_dir)
if os.path.isdir(os.path.join(args.base_video_dir, p))
]
print(f"running VOS prediction on {len(video_names)} videos:\n{video_names}")
for n_video, video_name in enumerate(video_names):
print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}")
if not args.track_object_appearing_later_in_video:
vos_inference(
predictor=predictor,
base_video_dir=args.base_video_dir,
input_mask_dir=args.input_mask_dir,
output_mask_dir=args.output_mask_dir,
video_name=video_name,
score_thresh=args.score_thresh,
use_all_masks=args.use_all_masks,
per_obj_png_file=args.per_obj_png_file,
)
else:
vos_separate_inference_per_object(
predictor=predictor,
base_video_dir=args.base_video_dir,
input_mask_dir=args.input_mask_dir,
output_mask_dir=args.output_mask_dir,
video_name=video_name,
score_thresh=args.score_thresh,
use_all_masks=args.use_all_masks,
per_obj_png_file=args.per_obj_png_file,
)
print(
f"completed VOS prediction on {len(video_names)} videos -- "
f"output masks saved to {args.output_mask_dir}"
)
if __name__ == "__main__":
main()
# Training Code for SAM 2
This folder contains the training code for SAM 2, a foundation model for promptable visual segmentation in images and videos.
The code allows users to train and fine-tune SAM 2 on their own datasets (image, video, or both).
## Structure
The training code is organized into the following subfolders:
* `dataset`: This folder contains image and video dataset and dataloader classes as well as their transforms.
* `model`: This folder contains the main model class (`SAM2Train`) for training/fine-tuning. `SAM2Train` inherits from the `SAM2Base` model and provides functions to enable training or fine-tuning SAM 2. It also accepts all training-time parameters used for simulating user prompts (e.g. iterative point sampling).
* `utils`: This folder contains training utils such as loggers and distributed training utils.
* `scripts`: This folder contains the script to extract the frames of SA-V dataset to be used in training.
* `loss_fns.py`: This file has the main loss class (`MultiStepMultiMasksAndIous`) used for training.
* `optimizer.py`: This file contains all optimizer utils that support arbitrary schedulers.
* `trainer.py`: This file contains the `Trainer` class that accepts all the `Hydra`-configurable modules (model, optimizer, datasets, etc.) and implements the main train/eval loop.
* `train.py`: This script is used to launch training jobs. It supports single and multi-node jobs. For usage, please check the [Getting Started](README.md#getting-started) section or run `python training/train.py -h`.
## Getting Started
To get started with the training code, we provide a simple example to fine-tune our checkpoints on the [MOSE](https://henghuiding.github.io/MOSE/) dataset, which can be extended to your custom datasets.
#### Requirements:
- We assume training on A100 GPUs with **80 GB** of memory.
- Download the MOSE dataset using one of the provided links from [here](https://github.com/henghuiding/MOSE-api?tab=readme-ov-file#download).
#### Steps to fine-tune on MOSE:
- Install the packages required for training by running `pip install -e ".[dev]"`.
- Set the paths for MOSE dataset in `configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml`.
```yaml
dataset:
# PATHS to Dataset
img_folder: null # PATH to MOSE JPEGImages folder
gt_folder: null # PATH to MOSE Annotations folder
file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training
```
- To fine-tune the base model on MOSE using 8 GPUs, run
```bash
python training/train.py \
-c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
--use-cluster 0 \
--num-gpus 8
```
We also support multi-node training on a cluster using [SLURM](https://slurm.schedmd.com/documentation.html). For example, you can train on 2 nodes by running
```bash
python training/train.py \
-c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
--use-cluster 1 \
--num-gpus 8 \
--num-nodes 2 \
--partition $PARTITION \
--qos $QOS \
--account $ACCOUNT
```
where `--partition`, `--qos`, and `--account` are optional and depend on your SLURM configuration.
By default, the checkpoint and logs will be saved under the `sam2_logs` directory in the root of the repo. Alternatively, you can set the experiment log directory in the config file as follows:
```yaml
experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
```
The training losses can be monitored using `tensorboard` logs stored under `tensorboard/` in the experiment log directory. We also provide a sample validation [split](../training/assets/MOSE_sample_val_list.txt) for evaluation purposes. To generate predictions, follow this [guide](../tools/README.md) on how to use our `vos_inference.py` script. After generating the predictions, you can run the `sav_evaluator.py` as detailed [here](../sav_dataset/README.md#sa-v-val-and-test-evaluation). The expected MOSE J&F after fine-tuning the Base plus model is 79.4.
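For reference, a prediction run on the sample validation split might look like the sketch below. The flag names follow the argument parser in `vos_inference.py`; the config, checkpoint, and data paths are placeholders to adapt to your setup (run `python ./tools/vos_inference.py -h` to confirm the available options).
```bash
python ./tools/vos_inference.py \
  --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \
  --sam2_checkpoint ./sam2_logs/<your_experiment>/checkpoints/checkpoint.pt \
  --base_video_dir <path_to_MOSE_JPEGImages> \
  --input_mask_dir <path_to_MOSE_Annotations> \
  --video_list_file training/assets/MOSE_sample_val_list.txt \
  --output_mask_dir ./outputs/MOSE_sample_val_pred
```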
After training/fine-tuning, you can then use the new checkpoint (saved in `checkpoints/` in the experiment log directory) in the same way as the released SAM 2 checkpoints (as illustrated [here](../README.md#image-prediction)).
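As a minimal sketch (assuming the `build_sam2_video_predictor` factory used by `vos_inference.py` above; the config and checkpoint paths below are placeholders), loading the fine-tuned checkpoint for video prediction could look like:
```python
import torch
from sam2.build_sam import build_sam2_video_predictor

# Placeholder paths: point these at the model config you fine-tuned and the
# checkpoint written under <experiment_log_dir>/checkpoints/ after training.
predictor = build_sam2_video_predictor(
    config_file="configs/sam2.1/sam2.1_hiera_b+.yaml",
    ckpt_path="./sam2_logs/<your_experiment>/checkpoints/checkpoint.pt",
)

with torch.inference_mode():
    # A directory of JPEG frames for one video (placeholder path).
    state = predictor.init_state(video_path="<path_to_jpeg_frames>")
```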
## Training on images and videos
The code supports training on images and videos (similar to how SAM 2 is trained). We provide classes for loading SA-1B as a sample image dataset, SA-V as a sample video dataset, as well as any DAVIS-style video dataset (e.g. MOSE). Note that to train on SA-V, you must first extract all videos to JPEG frames using the provided extraction [script](./scripts/sav_frame_extraction_submitit.py). Below is an example of how to set up the datasets in your config to train on a mix of image and video datasets:
```yaml
data:
train:
_target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
phases_per_epoch: ${phases_per_epoch} # Chunks a single epoch into smaller phases
batch_sizes: # List of batch sizes corresponding to each dataset
- ${bs1} # Batch size of dataset 1
- ${bs2} # Batch size of dataset 2
datasets:
# SA1B as an example of an image dataset
- _target_: training.dataset.vos_dataset.VOSDataset
training: true
video_dataset:
_target_: training.dataset.vos_raw_dataset.SA1BRawDataset
img_folder: ${path_to_img_folder}
gt_folder: ${path_to_gt_folder}
file_list_txt: ${path_to_train_filelist} # Optional
sampler:
_target_: training.dataset.vos_sampler.RandomUniformSampler
num_frames: 1
max_num_objects: ${max_num_objects_per_image}
transforms: ${image_transforms}
# SA-V as an example of a video dataset
- _target_: training.dataset.vos_dataset.VOSDataset
training: true
video_dataset:
_target_: training.dataset.vos_raw_dataset.JSONRawDataset
img_folder: ${path_to_img_folder}
gt_folder: ${path_to_gt_folder}
file_list_txt: ${path_to_train_filelist} # Optional
ann_every: 4
sampler:
_target_: training.dataset.vos_sampler.RandomUniformSampler
num_frames: 8 # Number of frames per video
max_num_objects: ${max_num_objects_per_video}
reverse_time_prob: ${reverse_time_prob} # probability to reverse video
transforms: ${video_transforms}
shuffle: True
num_workers: ${num_train_workers}
pin_memory: True
drop_last: True
collate_fn:
_target_: training.utils.data_utils.collate_fn
_partial_: true
dict_key: all
```
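Under the hood, Hydra instantiates `training.dataset.sam2_datasets.TorchTrainMixedDataset` from the `data.train` node above, and each call to its `get_loader` returns a mixed dataloader that samples batches from the per-dataset loaders in proportion to their sizes (or to an explicit `dataset_prob`). As a rough, illustrative sketch of that mapping (toy `TensorDataset`s stand in for the real image/video datasets, and a single-process `gloo` group is created because the loader internally uses a `DistributedSampler`):
```python
import torch
import torch.distributed as dist
from torch.utils.data import TensorDataset

from training.dataset.sam2_datasets import TorchTrainMixedDataset

# get_loader builds a DistributedSampler, so a (single-process) group is needed.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

# Toy stand-ins for an image dataset and a video dataset.
image_like = TensorDataset(torch.randn(960, 3, 8, 8))    # 960 "images"
video_like = TensorDataset(torch.randn(80, 8, 3, 8, 8))  # 80 "videos" of 8 frames

mixed = TorchTrainMixedDataset(
    datasets=[image_like, video_like],
    batch_sizes=[8, 2],  # per-dataset batch sizes, as in the YAML above
    num_workers=0,
    shuffle=True,
    pin_memory=False,
    drop_last=True,
    phases_per_epoch=1,
)
loader = mixed.get_loader(epoch=0)  # yields batches mixed from both datasets
```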
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
28191f94
662487fe
80906bf9
7e704f2e
efa25913
b6f03bd9
6834d249
5a723c30
07779415
4ce088c6
199995b5
54273925
4fa342f5
110da3cf
65856fa0
46705bb3
d869a3cf
555aa049
8f01fb2c
37b07a28
5e80b3dd
ba0e4dd4
6f5144b6
acec8407
93723f88
c7c7528c
97f58761
e71f9faa
e64c13dc
8830d59d
0e4aeed9
63437cf3
95215aa1
255f86ef
dc54aab2
327cd258
198021ad
c690220c
d25ff89d
7875b874
4fa6d325
9fc933f6
4d8baafe
55ae6921
6a3bc149
89f8163f
2d65d2ac
dba172b1
a14de179
4017d1b3
52ddf44c
3ba93641
34a5f964
da7dee28
872b76de
1dc12eca
265a69f4
86a2b59f
51e5ca25
ddf80bcd
6786602e
4fa28c89
f56942e9
2184bb93
d883e976
bfe1469e
bc4e7b11
1c80acb0
2b0e34d3
56b9ce41
15f0b0cd
cc5d0dd1
1b7eada8
7286b176
0ab42ab1
adb82dc9
c060b1e6
3da63bd5
5488796e
d7066e20
aab5ed11
17f66311
24df9789
208fa934
7ce2c865
debe4249
4c56bbea
149dbae2
beb693c9
49eb0315
e7ad4717
4e016d5a
95e24093
07b5d86c
80701b6c
337dfa1e
b624a46e
3f849de8
5db21df2
47891b4c
a966d7fd
013103f6
da5e4bc5
ba9ea03d
526195de
57f3a53e
b3aff7f8
26048547
bb7ee856
aef0d049
e35a8262
57ad022e
f45d3823
e5e9eb29
39cc637e
a4fc4f17
dd5a4739
bbe97d18
33602f6b
9061dac9
23454d80
a20baeec
794f01d4
02de2f2a
055fca57
a69df343
e307510e
d07ad1be
1fc5e086
db6533a5
fe9706b7
87e32230
8ba58e4c
561f6380
2ab9ba0f
86571569
756cc6c9
aa185af5
c6d7f94b
7f54c579
71f4b40e
4190c83a
fef0aba4
2f7c71bb
e4b6f2ef
76adaeea
11cdeb64
733f2a02
e50dbddb
f643141f
d2e75e95
84559bc3
7ade3068
e69db797
0b787263
57895315
d7969c29
62529cd4
203733e7
48fd97a6
723fd024
849f0efb
aafea009
dd4eb8f1
d18554ae
f3c0f0cf
90fe55b9
b0ffaf3b
e79ecd47
d670ce7b
56a5643a
90ff1d09
1fb378d9
57014c7d
994ed763
5bc7ea74
e99bd793
cbb66185
5f3fcff6
05ed1023
85efa9e3
652929ce
905d8740
a6fcde01
0fdf67f7
a5cf4c8d
e1c48bdd
782551f7
6acd353f
c30641cf
81d12756
51befc31
9d5ab5ca
d262b7e4
2cd705a9
f7360199
d3f3bf9d
028f6f64
94767cb4
3a739934
72433603
ec66879d
6149becc
5845c157
c5082b3c
f89b54d0
f3ada126
409dcb8a
4411fdee
eb93ed20
9cb1ba0e
b8e1ec26
7edd8b4f
5e9412c0
2744f35a
dafeb75e
f3f072f2
6f1df574
5a064706
89c76ac4
a6adef89
76303516
dbd67417
a53ef3fa
10552818
ac7deb19
2d403c59
55c157f1
214aeac3
a9f5e251
d7807996
d1dba33b
1367e367
44476e77
0644075b
eda37457
f2de4198
9a4ce701
46e00caf
2ae75f99
cd49fb99
4e4483e7
a0669957
a6f0d882
9ce1d54a
1fc2314b
21f363b3
32ecef67
70bcaf68
115348f9
60827ada
a218e951
6d30d5ac
6da17988
f22c39ce
5825f0e0
f415f9ad
0d4feda2
832fc243
414ca58b
a92390a0
ddd383cc
43dc67f7
962ae0e2
6dd74e7b
2bcd6c3b
b394847f
637fd121
d46e771b
f6bfc699
63f138de
932ad0a6
2080824a
52fa9174
843d3bf7
f3431885
5c20c48a
134a2ab0
2ea465de
f6786ab5
2bf49664
a49ce97b
6a50e93a
a7c21e95
616ad8ec
0a8d7b41
b0c90527
2d893fb7
19310598
7744dc51
4539b907
9d299f60
e495537a
0b02886a
f4c4a2ca
e957b2b5
e6f3bf07
258944c8
54364322
ebb77f95
0af03282
cbdbc6c3
494ecef0
ee91f783
9698f06e
11e16068
b942ce0a
423a50e6
fb16e746
9c88ae45
8620c024
d3af3c85
780a25de
e569a15f
c4f9f19e
1106f3a7
d37e29a7
e53611da
fdb2e432
18ad3117
6fcd426d
3bfa8379
3b19c5c3
ff1142df
cd182615
b60ea255
b3f5d019
6dc5e55d
103166c7
37af9ac1
ad1881d1
731149b3
90e3338a
6aa0b6f2
a25316a3
dc8679e0
571fb490
80afed16
983a551b
a58578e5
2bc0bba4
1143b3fe
fdd8dd49
7fe2bf77
890ef032
8466eeb2
c791ddbb
631b82bd
78bf9b51
a99df45f
2bdb692f
e89b1501
4e6aa1e8
e5665030
fe21fd5c
635577d5
4414cd3a
03c99e83
ff041cd1
c33adbc2
a988ec74
576031e0
03c21af7
79b25f4b
bbc485d6
d36d5a0d
efdab888
b20e6781
81fdc526
e1c26a53
7c6d3504
52a04667
f22e34d4
bb936ead
13f0606c
d2abc61e
af509e8f
bea1c144
e15e4de8
e727099f
b30744df
ffb6a2e4
0d31d3a6
a23048fe
7d452630
6c736334
046ed4f4
94f4c2aa
c290cfd3
f7203226
2fdae3c5
7c78e351
02b72b8d
2d22d3be
ba28d02e
197f6587
43199a98
b563b04f
9293b755
9cef7489
d156b96f
15e9161e
6d094cd5
0d876a65
c818d30a
8094b12b
a4a8e24b
14655f54
11c14893
8a48f62a
7f3d9c22
d952481c
03e0f9b8
28980657
6a0b5563
5879983c
37549a79
4a7162bd
7a6aa1ef
0dc1b78c
f6dba17b
1dba51af
b2f4d608
e2e6f421
464066da
5d24e4ea
1e75004d
a02ed92c
673adbcc
c2a0c0fd
85addee5
54b8f502
f5d2d8d3
a19507e1
803e1756
0d1fe009
5968c2d8
b926e1ad
a9162e14
ae470d2b
bd731802
68c879f2
21fe05d9
c1ed21d0
831498e4
cc45a7f2
cb170015
59750be4
30d1cb6b
03e5f069
106d33db
3f003746
3e5ad020
8bc5a91c
64b89eb5
bfd28682
f8687b9a
7bbf38ee
d6d92b30
ceaa6c65
677c8ed7
dc33acf8
cfd1de31
e5be4781
85585220
5d2316f6
dd3f4a07
34535f5f
3ae0bc5d
f521e3c5
74c2284f
12a42fd9
61403519
88cd32f3
662a1846
825a1944
cf376cf1
8465d99c
61a2e246
62d44645
103b3ca8
c7e745ed
4ed71139
230c2edf
529c6889
9e509c0d
54b9dea2
a8934c0d
29cffe2f
48017512
c9f7f69d
ce691ee6
21c89360
3b97c07b
ebd82d35
2895bb8b
7043c5c1
85d694d7
88fd7507
18d8931e
aa718745
89b671bb
0d8d30ae
26163977
a6121689
1589579d
159789c4
f5ca8271
fcc16740
3158be0b
860fc1f7
3f54a330
82f24ce7
069f6a2a
2fa9c523
c9f1d87f
efe9cbca
8f969ea5
4f5db794
62c501f8
2d3b0320
c99637f0
0f3b1fcb
6e4ee861
e0d9aff0
230ddb91
e14d1f96
c83aa6a1
eabdf66a
6783a303
81659eb2
ce954bd7
9a48c0c9
0ab807b4
f0617f71
fe86f2f8
61d80e22
e4b6d2a0
ac093040
0e05fabe
d0b507c3
3d828137
c4fa0bab
f7783321
ec27366a
404e4c58
073baf48
0f685e01
b0e98fdd
b4891f7f
a46b7b77
ee059f99
3c87888e
8d23ddcc
2d8d7d35
5680be79
fc79c03e
20660b72
53f67585
90956534
7e709e2d
dae93f5c
54b9dbba
cc41ba05
1e207fe0
a9c6abf2
35e0ca09
e3dcd186
1b8bb699
92162474
cdad6812
50b91533
570215ac
6042d64a
b6e2c041
08746283
7a056996
b8651773
adf443e1
6a6e0e3b
886ed981
c1d57fea
43030c4c
7ebfbf57
0770ad03
e85301d5
31ac3d98
acaef45e
8f415dd1
fe2dc281
2c0b9d99
8e24501e
911ec4ad
8036b58e
c3b350b9
b6cadd11
a3a80cf7
88ab50cd
59c755a8
1339321a
91b2f707
97b0811e
1da33959
31b09833
c1a40349
708098a9
1f220f98
999e07cb
0b5e5d29
94c63453
b826d642
a598602d
4c83eab8
2efd5e50
6ec5da3a
9fcd95eb
9a2c6b5b
c205a718
e638e950
cb43141c
494dd91d
c4957274
4975a81d
a1f4c54d
51e6fafa
514490e5
b0d09e6a
c6726eb8
06772c9a
5a65ffd7
3657c62b
03012cfd
529df209
f1c38e66
ab417352
118a067e
8957514f
22e8b380
3b1a4616
a4457543
57c9f6e0
e362c16b
0f809e41
857e375e
9cff25e3
d754fb65
6ad44b86
051052d8
a4564b94
f68507d0
80a7cf7b
ad8cd1e0
60b19cd3
274fe944
f06632aa
628a337b
92c96c05
87fc565c
6f6e6c37
228a0234
6487110a
aa911a8e
40c47fa3
9606508b
6ba9e61f
c8c1d5a9
cf01df5b
9421b9ad
006e6b64
1c28e081
06273084
8925e11b
b46c822b
00501424
cfd946b2
2e92a7dc
1c5f5bb6
1d29944c
8248698e
19247506
1eac1aff
ee9caa47
4a41cbf8
d97c9309
4ca87c14
9707f1e3
8bb9a221
6605e67d
95cf72d7
1c6fb814
033130b2
4344808d
5f14e5d2
a810399b
e325a6d4
7014ddf4
725d4bfb
790285e8
1a6a731f
fbfb6e30
0d4d88f6
80ce18a4
572495b7
4b44dc50
95dce33c
4a6fb202
3142014e
a3c56751
96b2a414
c4aa176c
fd1e394f
93f0f509
f494e9fa
bfa42a75
db5319c7
aa92e070
81220a93
e4a72496
fc467bf1
5397b01d
1dc0c9a0
f6f8b4a6
53dc7db4
8ef303eb
62ca45c9
e9d3465e
3784e3f6
8c934e67
5ba84e3f
30e41f1e
61cf0ec8
e93e8f01
fc6086dd
a95f0aea
33a04ef2
6f295adb
d2aa8c66
724cc810
d8623d26
8d0d641a
4bda7a76
38030c69
56199c41
d2f4b9e2
a7b8ac96
64044df1
fd1078cc
0165667b
16e1cca7
915f0d9a
eeaaa67e
378430d5
a84c60e6
b4ae36cc
2a3a0571
13e6df75
aa348c45
59d7a11d
68954daf
d6f883c6
f28b429a
32dc49d4
ccf14ee0
7d512591
9bdabdb2
ed878d94
54eda06d
132561ee
3c4b6736
0367af42
531c1c36
843d8f25
333bdbdc
c3c21268
07b00746
c7fe0584
49fc9f2e
9ed4317a
d29991b4
98b0033d
f0b922bf
89fe6899
58264713
2f49220a
6ff85ca5
4b96b2c8
a42f54f5
aa425600
22fdee40
dde85a9d
3722f6fe
e7529cbc
5ae23f9f
cc32235b
730bc486
b12701b7
a96b3010
16130bd3
2c713560
f7935d24
a7eb6616
0d6e7177
100edaef
0442a954
60f4fa43
37bf7edf
76b18413
ab0646a9
c575434d
1e356390
5416fbb7
df7cf932
269872de
9033b607
c2e88575
932542cd
23e046fb
3d08dadd
7999adc5
ed81c485
3bd7facd
1feae28e
8d72533b
6a8d35d6
65308bdc
7f0b7662
98290486
fee3371f
c463c7e5
faf7d852
75c34dc5
96a6722e
e5605136
851bc5d9
15c41c4b
6a39e104
5fbff256
0e7001dd
5411113f
3ea2f7f2
242b74b1
87727003
ec6dd0e9
980baf58
9d0b7bf1
9113c9d4
5ebef6bd
a5f70ce7
b0240233
06ad78e0
8745edd0
d8e8d984
ac32a655
38568758
d48c552d
0b27d5f7
c65d0736
800e3c14
d37a5857
bcebc660
d3ab52cc
405e3ee7
e33cddc9
b0197182
89fd5681
9e192417
8554c402
aae923b8
31af515d
75b26f88
60471744
460945aa
c0fe8e1a
1731babb
2e85e35d
f9c20062
115da184
ddfa88c7
359003f8
dfa99126
bf04814f
f407a414
e18723c4
0a7a3629
c07ab37e
1251a1c9
4d09d22a
5984ed74
34504f63
ced51047
08ff419c
d942e98c
2697f864
3b671a61
72a2f7e2
48e7cafe
6adad2f7
18840617
1e44f47e
36cc4055
8c494902
2982de7a
6a428397
c4a0ecfb
231d6945
fe470104
f93e1bd0
bd18bc5a
7bd70d93
8f81a0ee
db78e7a1
7593caea
86d5b29b
5457b298
0d967fd1
62372d4c
68259db3
f0944ea2
7b017dbf
bcb6e338
03692b14
f7d36a47
1ca2531a
6728528d
1fc0e6a8
0ba9c5ad
a386eaa2
b0c5459f
1d64aff3
b97d4f1a
b3745d91
c461003e
910bf878
ae42601c
8d2ddeff
aaecaa39
250b5034
edb11192
7bfe9b57
6d533759
51586b36
a38d648a
8fdb48e5
6075d6b0
3588ea03
bc844942
398d41f5
660e3b70
0b99f522
f169fd1b
7bfa2ab5
ab461319
25153e58
002b4dce
a2df1bee
550a7357
b604f2dd
2f477d05
bdf9eb5a
857ddc6e
c8f0fd41
6df96f15
e147ab26
788da8e8
02221fb0
d1d95c61
a3f0cb28
3a6e6ace
67c2909a
220382ab
eaed776d
aff08a61
b99d1bd6
9d9ae988
34ccea00
41dae436
18513251
ad57acd1
67f110fc
3f09f5c9
25ef7d43
12a5d0d7
3ff48b8b
26ed56e6
c047a092
bb8639e1
8788747f
584838d4
f8e5f837
657242e8
cb8eedf4
74a917f1
578f71da
c9b27125
22e1f53c
f40145c2
4795259b
3f313a2f
c9012bf6
22167a50
6e7f9437
ef51a724
356e0fcb
d3ea999d
08a5c662
85aa3b0e
579fadec
7bc95dc2
c097af8e
f01d8b9f
80fb79c6
ea65e6b7
29ff29f6
9e1f739d
b7fb59c9
e2160f17
0be33bc1
e96b9b04
b1affe79
c4f4b2e2
f4c8ffb1
6a009e50
a8828854
2786f841
a64e724c
5f54d077
7040385d
6e0f0ecc
f33d3c15
8108b358
46a502de
1e0fb02a
ddbdfa32
e7b34ab6
c9080ed1
395224b3
33f9ab47
c245ecda
c28d81a9
37303a3b
6380dd6f
2fb5a55b
83b7c53c
41c8d0d2
3aab2d13
dc7d21fb
86a88668
37bb38fe
ab6413a8
bbe585b2
a0ca072a
9d5940d2
ddb1d0b1
a946317a
988b29a4
89dc0432
5df8490d
5e167efa
50a86faa
fe6a535a
a9f8b8b4
6e2dce1b
d0696759
c09da3b2
f07dd347
67408899
406165ff
a4a9d03d
9b5f0f47
5f3e8022
1d7a23e0
25af2eeb
82a3db34
c9351029
6c93d44c
f088ad1c
9ee59f51
b5276b3f
ca74a924
781af187
fa3e0b85
b898c99e
1ca51f06
5a92a0c1
138c81fe
d0722d0f
05a7d84d
e18f1dea
799a2d61
8276e558
f0ba8748
ce733e8a
2f9d0911
58f24fa4
66a25278
3135d31d
4b9223ee
bdd5e6b3
ddbebec1
8dbebbd9
3020b38f
e607450d
724a5d1c
91b754c5
2e85e790
3a407bd9
fd137178
a304029b
4023fc77
440d5072
2eb73c7c
164a7305
b33ade7c
277ad883
b0f7e75c
74107936
83924bdb
b72beb78
86c01d64
f6f441eb
23b9a3ea
80b73f1a
93c6411d
1e95ef5e
800b5eac
9519832a
ae043406
b06a902e
1dbca5cc
571f88a1
b1faf52b
45572497
8d016cdb
f92cdae8
316931f8
f9884439
e1b7f212
e23c6392
ccfae073
5aa1efda
74f0687c
eaff3301
b6520a94
c5398714
15e7e4d1
0fc00006
8cf49218
3a8ddc0a
e7e2a0b9
eec4c008
8d73085e
77e246da
00e92ab4
f76f6cf9
19801183
233406ef
b80e028c
342c0b2a
a2768c47
99350a74
adbd400b
f3978ade
b87a4f6c
fa95a6a2
6dff20c9
935b5ad8
dbbbb401
1b6472c1
9c0e6331
04ae7a6b
4c94e4f3
90cb46cb
2831ecf5
ff77a145
79af6097
ba61a719
abcb7665
7e87750e
c4c7bc5d
3a670b81
3d9a7023
82667d52
a4587f62
ca619b7f
7c5462f5
bda5c60d
e6e48ac8
405c6000
7981f344
f7375ab3
bb467ff9
cfc68a82
e417a6d8
1a6177c1
7b75dace
b1af350d
484d48a3
1f805416
7416ab4e
1291276c
9e85179b
5a74660c
7e6d00df
01e3cec8
ee2c0688
f6de8226
a217538c
b432c3ef
49e5ff4e
035359e5
8ae8e7ed
2da12766
cac39070
115adda4
1a2872dc
fac3378e
294e7bf8
a1a4991f
c062f4d7
72b2b77d
158062aa
9ae447a7
a7b05677
fdfd5d56
eac1a9e6
a5905593
59992293
84298fae
f708e55f
093d3d93
75d26197
924f5d88
3184a7ec
b454fdbc
2d9101b8
ae70fb7c
4385b2c4
63b37343
0b4b662c
2883ae72
ffcab778
0f96e2d7
897066e3
f23e98ad
797a7b7e
2fc476f9
32e5d721
5bad0bab
267bfd6c
0a43a414
56c56ca9
9a1146b3
c6ad7aaf
78a1f4b1
fc455e73
072e7b3f
77ccb57d
a76ee415
8cdcfc17
5d518b42
376dd830
0e843fc8
2af0e766
2bd4e845
de2f2a6a
ade9ee91
001ca3cb
fc4c1c67
8ef55579
b84ce852
4cc8528a
767ffaaa
112a2ef0
a338c8aa
cbd144f5
5ff72128
86a949e2
9f2323ac
1fab1d1c
75924351
ef55817b
02deca50
4d979d99
4d65f873
28470fa0
0d1575fe
06ea172e
29a6ddc2
797f1bec
780e7a99
b9ed5b44
02a236b4
607d8ff5
af5666b2
0558d0ed
a938c6b2
103df575
77110e80
739e5a07
6763a576
06ebc138
ba4b3b09
b35cc2f3
4e0597a0
5949ee84
5348d547
323c4236
b3b51117
55727ddd
ab2714f3
d2878895
c0734cb3
94f7c53e
2a2745e5
442ffb54
3592425a
50ae03b0
5f150435
3067f9fa
9ffb2818
adeaf5aa
31caacec
1cd99b86
aa22f9d0
8fa50320
e6348d2c
42ff84a5
8c8b7913
c96adcbc
495be321
db735509
ee113fc4
a678cdab
c409ca4d
68d2b259
592b4dee
4e2b4dc7
eb4d26e1
2009a00f
bec5c89d
67191f24
a3e85b4b
da7080cd
80d978e9
36dcb93f
a41e8c44
12fdc864
46d140ea
657c9dd9
a86f84ee
90c1c43d
33015509
afc7664d
23df06e1
291d4799
0ab75563
251bf059
bcefdcc4
ce9a2796
94d3403a
8f2e04bc
f9cda066
9dfa2cc5
66924c91
e765a09e
15654ee1
48e0bd39
ee095221
2463609b
544d0d1f
51b8c2e1
d321dde4
4cb11a5f
d7058a0d
37af282a
fabae187
7be91184
181ec185
2d16ceeb
b56be4b1
6699eff0
79acac96
d61c4665
0c13e1e7
100f6ecf
71217dfc
82df0888
4c42c747
c9fdf703
d2efeb4b
69ed9d14
64914fb6
255bedbc
4ea934d8
a034feb2
e4f4ddae
e36a3026
c1489591
111bb373
e1d9fb32
93e22d48
c1ec4b26
d9638e69
60ab04c5
cfe7773a
62132822
2f5fb2a3
7bdd197d
033333fd
130fcdbe
12e509c2
67138c33
6f90cc5f
4e3020fe
bbdd8bb7
b399ccdb
fecd10d2
2e0967f7
f509054f
792c6ff7
48e2afc5
d904c048
111e0a5c
b83024e2
e6a7b79c
bdc5ccf7
b8146d00
9d394f1a
645b84f9
95ab2d0f
e6f8a31d
b4f876fb
dc2c570d
3afd02d7
5c80c82c
b1b32ddd
9f25fc61
ba538072
f8916fef
43c04ad2
a658e949
2861dd53
f6e40aba
09d305d1
aac33bff
8d9d4c08
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from typing import Callable, Iterable, List, Optional, Sequence
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
from torch.utils.data.distributed import DistributedSampler
class MixedDataLoader:
def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
"""
Args:
dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
mixing_prob (torch.FloatTensor): Probability of sampling from each dataloader.
"""
assert len(dataloaders) == mixing_prob.shape[0]
self.dataloaders = dataloaders
self.mixing_prob = mixing_prob
# Iterator state
self._iter_dls = None
self._iter_mixing_prob = None
self.random_generator = torch.Generator()
def __len__(self):
return sum([len(d) for d in self.dataloaders])
def __iter__(self):
# Synchronize dataloader seeds
self.random_generator.manual_seed(42)
self._iter_dls = [iter(loader) for loader in self.dataloaders]
self._iter_mixing_prob = self.mixing_prob.clone()
return self
def __next__(self):
"""
Sample a dataloader to sample from based on mixing probabilities. If one of the dataloaders is exhausted, we continue sampling from the other loaders until all are exhausted.
"""
if self._iter_dls is None:
raise TypeError(f"{type(self).__name__} object is not an iterator")
while self._iter_mixing_prob.any():  # at least one dataloader with non-zero probability
dataset_idx = self._iter_mixing_prob.multinomial(
1, generator=self.random_generator
).item()
try:
item = next(self._iter_dls[dataset_idx])
return item
except StopIteration:
# No more iterations for this dataset; set its mixing probability to zero and try again.
self._iter_mixing_prob[dataset_idx] = 0
except Exception as e:
# log and raise any other unexpected error.
logging.error(e)
raise e
# Exhausted all iterators
raise StopIteration
class TorchTrainMixedDataset:
def __init__(
self,
datasets: List[Dataset],
batch_sizes: List[int],
num_workers: int,
shuffle: bool,
pin_memory: bool,
drop_last: bool,
collate_fn: Optional[Callable] = None,
worker_init_fn: Optional[Callable] = None,
phases_per_epoch: int = 1,
dataset_prob: Optional[List[float]] = None,
) -> None:
"""
Args:
datasets (List[Dataset]): List of Datasets to be mixed.
batch_sizes (List[int]): Batch sizes for each dataset in the list.
num_workers (int): Number of workers per dataloader.
shuffle (bool): Whether or not to shuffle data.
pin_memory (bool): If True, use pinned memory when loading tensors from disk.
drop_last (bool): Whether or not to drop the last batch of data.
collate_fn (Callable): Function to merge a list of samples into a mini-batch.
worker_init_fn (Callable): Function to init each dataloader worker.
phases_per_epoch (int): Number of phases per epoch.
dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0
"""
self.datasets = datasets
self.batch_sizes = batch_sizes
self.num_workers = num_workers
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
self.collate_fn = collate_fn
self.worker_init_fn = worker_init_fn
assert len(self.datasets) > 0
for dataset in self.datasets:
assert not isinstance(dataset, IterableDataset), "Not supported"
# `RepeatFactorWrapper` requires calling set_epoch first to get its length
self._set_dataset_epoch(dataset, 0)
self.phases_per_epoch = phases_per_epoch
self.chunks = [None] * len(datasets)
if dataset_prob is None:
# If not provided, assign each dataset a probability proportional to its length.
dataset_lens = [
(math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
for d, bs in zip(datasets, batch_sizes)
]
total_len = sum(dataset_lens)
dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
else:
assert len(dataset_prob) == len(datasets)
dataset_prob = torch.tensor(dataset_prob)
logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0"
self.dataset_prob = dataset_prob
def _set_dataset_epoch(self, dataset, epoch: int) -> None:
if hasattr(dataset, "epoch"):
dataset.epoch = epoch
if hasattr(dataset, "set_epoch"):
dataset.set_epoch(epoch)
def get_loader(self, epoch) -> Iterable:
dataloaders = []
for d_idx, (dataset, batch_size) in enumerate(
zip(self.datasets, self.batch_sizes)
):
if self.phases_per_epoch > 1:
# Main epoch that loops over the entire dataset
# len(main_epoch) == phases_per_epoch * len(epoch)
main_epoch = epoch // self.phases_per_epoch
# Phase within the main epoch
local_phase = epoch % self.phases_per_epoch
# Start of a new data epoch, or the job was resumed after preemption.
if local_phase == 0 or self.chunks[d_idx] is None:
# set seed for dataset epoch
# If using RepeatFactorWrapper, this step correctly re-samples indices before chunking.
self._set_dataset_epoch(dataset, main_epoch)
# Separate random generator for subset sampling
g = torch.Generator()
g.manual_seed(main_epoch)
self.chunks[d_idx] = torch.chunk(
torch.randperm(len(dataset), generator=g),
self.phases_per_epoch,
)
dataset = Subset(dataset, self.chunks[d_idx][local_phase])
else:
self._set_dataset_epoch(dataset, epoch)
sampler = DistributedSampler(dataset, shuffle=self.shuffle)
sampler.set_epoch(epoch)
batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
dataloaders.append(
DataLoader(
dataset,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
batch_sampler=batch_sampler,
collate_fn=self.collate_fn,
worker_init_fn=self.worker_init_fn,
)
)
return MixedDataLoader(dataloaders, self.dataset_prob)
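# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training pipeline): mixing two
# plain DataLoaders over toy tensors. In training, TorchTrainMixedDataset.get_loader
# builds the per-dataset loaders with distributed batch samplers instead.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    short_loader = DataLoader(TensorDataset(torch.arange(4).float()), batch_size=2)
    long_loader = DataLoader(TensorDataset(torch.arange(100).float()), batch_size=10)

    # Sample the short loader 30% of the time and the long one 70% of the time.
    mixed = MixedDataLoader(
        dataloaders=[short_loader, long_loader],
        mixing_prob=torch.tensor([0.3, 0.7]),
    )
    for batch in mixed:
        # Once the short loader is exhausted, its probability is zeroed and the
        # remaining batches all come from the long loader until it is exhausted too.
        print(batch)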