# ==========================================================
# Modified from mmcv
# ==========================================================
import json
import pickle
from abc import ABCMeta, abstractmethod
from pathlib import Path
import yaml
try:
from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
from yaml import Loader, Dumper
# ===========================
# Register handlers
# ===========================
class BaseFileHandler(metaclass=ABCMeta):
@abstractmethod
def load_from_fileobj(self, file, **kwargs):
pass
@abstractmethod
def dump_to_fileobj(self, obj, file, **kwargs):
pass
@abstractmethod
def dump_to_str(self, obj, **kwargs):
pass
def load_from_path(self, filepath, mode="r", **kwargs):
with open(filepath, mode) as f:
return self.load_from_fileobj(f, **kwargs)
def dump_to_path(self, obj, filepath, mode="w", **kwargs):
with open(filepath, mode) as f:
self.dump_to_fileobj(obj, f, **kwargs)
class JsonHandler(BaseFileHandler):
    def load_from_fileobj(self, file, **kwargs):
        return json.load(file, **kwargs)
def dump_to_fileobj(self, obj, file, **kwargs):
json.dump(obj, file, **kwargs)
def dump_to_str(self, obj, **kwargs):
return json.dumps(obj, **kwargs)
class PickleHandler(BaseFileHandler):
def load_from_fileobj(self, file, **kwargs):
return pickle.load(file, **kwargs)
def load_from_path(self, filepath, **kwargs):
return super(PickleHandler, self).load_from_path(filepath, mode="rb", **kwargs)
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault("protocol", 2)
return pickle.dumps(obj, **kwargs)
def dump_to_fileobj(self, obj, file, **kwargs):
kwargs.setdefault("protocol", 2)
pickle.dump(obj, file, **kwargs)
def dump_to_path(self, obj, filepath, **kwargs):
super(PickleHandler, self).dump_to_path(obj, filepath, mode="wb", **kwargs)
class YamlHandler(BaseFileHandler):
def load_from_fileobj(self, file, **kwargs):
kwargs.setdefault("Loader", Loader)
return yaml.load(file, **kwargs)
def dump_to_fileobj(self, obj, file, **kwargs):
kwargs.setdefault("Dumper", Dumper)
yaml.dump(obj, file, **kwargs)
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault("Dumper", Dumper)
return yaml.dump(obj, **kwargs)
file_handlers = {
"json": JsonHandler(),
"yaml": YamlHandler(),
"yml": YamlHandler(),
"pickle": PickleHandler(),
"pkl": PickleHandler(),
}
# ===========================
# load and dump
# ===========================
def is_str(x):
"""Whether the input is an string instance.
Note: This method is deprecated since python 2 is no longer supported.
"""
return isinstance(x, str)
def slload(file, file_format=None, **kwargs):
"""Load data from json/yaml/pickle files.
This method provides a unified api for loading data from serialized files.
Args:
file (str or :obj:`Path` or file-like object): Filename or a file-like
object.
file_format (str, optional): If not specified, the file format will be
inferred from the file extension, otherwise use the specified one.
Currently supported formats include "json", "yaml/yml" and
"pickle/pkl".
Returns:
The content from the file.
"""
if isinstance(file, Path):
file = str(file)
if file_format is None and is_str(file):
file_format = file.split(".")[-1]
if file_format not in file_handlers:
raise TypeError(f"Unsupported format: {file_format}")
handler = file_handlers[file_format]
if is_str(file):
obj = handler.load_from_path(file, **kwargs)
elif hasattr(file, "read"):
obj = handler.load_from_fileobj(file, **kwargs)
else:
raise TypeError('"file" must be a filepath str or a file-object')
return obj
def sldump(obj, file=None, file_format=None, **kwargs):
"""Dump data to json/yaml/pickle strings or files.
This method provides a unified api for dumping data as strings or to files,
and also supports custom arguments for each file format.
Args:
obj (any): The python object to be dumped.
file (str or :obj:`Path` or file-like object, optional): If not
specified, then the object is dump to a str, otherwise to a file
specified by the filename or file-like object.
        file_format (str, optional): Same as :func:`slload`.
    Returns:
        str or None: the serialized string if ``file`` is None, otherwise None.
"""
if isinstance(file, Path):
file = str(file)
if file_format is None:
if is_str(file):
file_format = file.split(".")[-1]
elif file is None:
raise ValueError("file_format must be specified since file is None")
if file_format not in file_handlers:
raise TypeError(f"Unsupported format: {file_format}")
handler = file_handlers[file_format]
if file is None:
return handler.dump_to_str(obj, **kwargs)
elif is_str(file):
handler.dump_to_path(obj, file, **kwargs)
elif hasattr(file, "write"):
handler.dump_to_fileobj(obj, file, **kwargs)
else:
raise TypeError('"file" must be a filename str or a file-object')
import json
import time
class TimeCounter:
    def __init__(self) -> None:
        self.clear()
def clear(self):
self.timedict = {}
self.basetime = time.perf_counter()
def timeit(self, name):
nowtime = time.perf_counter() - self.basetime
self.timedict[name] = nowtime
self.basetime = time.perf_counter()
class TimeHolder:
def __init__(self) -> None:
self.timedict = {}
def update(self, _timedict: dict):
for k, v in _timedict.items():
if k not in self.timedict:
self.timedict[k] = AverageMeter(name=k, val_only=True)
self.timedict[k].update(val=v)
def final_res(self):
return {k: v.avg for k, v in self.timedict.items()}
def __str__(self):
return json.dumps(self.final_res(), indent=2)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=":f", val_only=False):
self.name = name
self.fmt = fmt
self.val_only = val_only
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
if self.val_only:
fmtstr = "{name} {val" + self.fmt + "}"
else:
fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
return fmtstr.format(**self.__dict__)
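# --------------------------------------------------------------------------
# Usage sketch (illustrative, not from the original code): TimeCounter marks
# per-step wall-clock times, TimeHolder averages them across iterations via
# AverageMeter, and printing the holder dumps the averages as JSON.
def _example_timing():
    counter = TimeCounter()
    holder = TimeHolder()
    for _ in range(3):
        counter.clear()
        time.sleep(0.01)          # stands in for "forward"
        counter.timeit("forward")
        time.sleep(0.005)         # stands in for "backward"
        counter.timeit("backward")
        holder.update(counter.timedict)
    print(holder)                 # average seconds per named step, as JSON
    return holder.final_res()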
import argparse
import json
import warnings
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Dict, List
import numpy as np
import torch
from transformers import AutoTokenizer
from groundingdino.util.slconfig import SLConfig
def slprint(x, name="x"):
if isinstance(x, (torch.Tensor, np.ndarray)):
print(f"{name}.shape:", x.shape)
elif isinstance(x, (tuple, list)):
print("type x:", type(x))
for i in range(min(10, len(x))):
slprint(x[i], f"{name}[{i}]")
elif isinstance(x, dict):
for k, v in x.items():
slprint(v, f"{name}[{k}]")
else:
print(f"{name}.type:", type(x))
def clean_state_dict(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k[:7] == "module.":
k = k[7:] # remove `module.`
new_state_dict[k] = v
return new_state_dict
def renorm(
img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
) -> torch.FloatTensor:
# img: tensor(3,H,W) or tensor(B,3,H,W)
# return: same as img
assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
if img.dim() == 3:
        assert img.size(0) == 3, 'img.size(0) should be 3 but "%d". (%s)' % (
img.size(0),
str(img.size()),
)
img_perm = img.permute(1, 2, 0)
mean = torch.Tensor(mean)
std = torch.Tensor(std)
img_res = img_perm * std + mean
return img_res.permute(2, 0, 1)
else: # img.dim() == 4
        assert img.size(1) == 3, 'img.size(1) should be 3 but "%d". (%s)' % (
img.size(1),
str(img.size()),
)
img_perm = img.permute(0, 2, 3, 1)
mean = torch.Tensor(mean)
std = torch.Tensor(std)
img_res = img_perm * std + mean
return img_res.permute(0, 3, 1, 2)
class CocoClassMapper:
def __init__(self) -> None:
self.category_map_str = {
"1": 1,
"2": 2,
"3": 3,
"4": 4,
"5": 5,
"6": 6,
"7": 7,
"8": 8,
"9": 9,
"10": 10,
"11": 11,
"13": 12,
"14": 13,
"15": 14,
"16": 15,
"17": 16,
"18": 17,
"19": 18,
"20": 19,
"21": 20,
"22": 21,
"23": 22,
"24": 23,
"25": 24,
"27": 25,
"28": 26,
"31": 27,
"32": 28,
"33": 29,
"34": 30,
"35": 31,
"36": 32,
"37": 33,
"38": 34,
"39": 35,
"40": 36,
"41": 37,
"42": 38,
"43": 39,
"44": 40,
"46": 41,
"47": 42,
"48": 43,
"49": 44,
"50": 45,
"51": 46,
"52": 47,
"53": 48,
"54": 49,
"55": 50,
"56": 51,
"57": 52,
"58": 53,
"59": 54,
"60": 55,
"61": 56,
"62": 57,
"63": 58,
"64": 59,
"65": 60,
"67": 61,
"70": 62,
"72": 63,
"73": 64,
"74": 65,
"75": 66,
"76": 67,
"77": 68,
"78": 69,
"79": 70,
"80": 71,
"81": 72,
"82": 73,
"84": 74,
"85": 75,
"86": 76,
"87": 77,
"88": 78,
"89": 79,
"90": 80,
}
self.origin2compact_mapper = {int(k): v - 1 for k, v in self.category_map_str.items()}
self.compact2origin_mapper = {int(v - 1): int(k) for k, v in self.category_map_str.items()}
def origin2compact(self, idx):
return self.origin2compact_mapper[int(idx)]
def compact2origin(self, idx):
return self.compact2origin_mapper[int(idx)]
def to_device(item, device):
if isinstance(item, torch.Tensor):
return item.to(device)
elif isinstance(item, list):
return [to_device(i, device) for i in item]
elif isinstance(item, dict):
return {k: to_device(v, device) for k, v in item.items()}
else:
raise NotImplementedError(
"Call Shilong if you use other containers! type: {}".format(type(item))
)
#
def get_gaussian_mean(x, axis, other_axis, softmax=True):
"""
Args:
        x (Tensor): input feature maps of shape (B, C, H, W)
        axis (int): the spatial index over which to take the weighted mean
        other_axis (int): the other spatial index, which is summed out
    Returns: weighted (expected) index along ``axis``, of shape (B, C)
"""
mat2line = torch.sum(x, axis=other_axis)
# mat2line = mat2line / mat2line.mean() * 10
if softmax:
u = torch.softmax(mat2line, axis=2)
else:
u = mat2line / (mat2line.sum(2, keepdim=True) + 1e-6)
size = x.shape[axis]
ind = torch.linspace(0, 1, size).to(x.device)
batch = x.shape[0]
channel = x.shape[1]
index = ind.repeat([batch, channel, 1])
mean_position = torch.sum(index * u, dim=2)
return mean_position
def get_expected_points_from_map(hm, softmax=True):
"""get_gaussian_map_from_points
B,C,H,W -> B,N,2 float(0, 1) float(0, 1)
softargmax function
Args:
hm (float): Input images(BxCxHxW)
Returns:
weighted index for axis, BxCx2. float between 0 and 1.
"""
# hm = 10*hm
B, C, H, W = hm.shape
y_mean = get_gaussian_mean(hm, 2, 3, softmax=softmax) # B,C
x_mean = get_gaussian_mean(hm, 3, 2, softmax=softmax) # B,C
# return torch.cat((x_mean.unsqueeze(-1), y_mean.unsqueeze(-1)), 2)
return torch.stack([x_mean, y_mean], dim=2)
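# --------------------------------------------------------------------------
# Usage sketch (illustrative): get_expected_points_from_map is a soft-argmax,
# so a heatmap with a single hot pixel yields roughly that pixel's normalized
# (x, y) coordinate for each channel.
def _example_soft_argmax():
    hm = torch.zeros(1, 1, 16, 16)
    hm[0, 0, 4, 12] = 50.0  # hot pixel at row y=4, column x=12
    pts = get_expected_points_from_map(hm, softmax=True)  # shape (1, 1, 2)
    # pts[..., 0] is close to 12/15 (x), pts[..., 1] is close to 4/15 (y)
    return pts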
# Positional encoding (section 5.1)
# borrow from nerf
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs["input_dims"]
out_dim = 0
if self.kwargs["include_input"]:
embed_fns.append(lambda x: x)
out_dim += d
max_freq = self.kwargs["max_freq_log2"]
N_freqs = self.kwargs["num_freqs"]
if self.kwargs["log_sampling"]:
freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
else:
freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs["periodic_fns"]:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, i=0):
import torch.nn as nn
if i == -1:
return nn.Identity(), 3
embed_kwargs = {
"include_input": True,
"input_dims": 3,
"max_freq_log2": multires - 1,
"num_freqs": multires,
"log_sampling": True,
"periodic_fns": [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj: eo.embed(x)
return embed, embedder_obj.out_dim
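# --------------------------------------------------------------------------
# Usage sketch (illustrative): with multires=L the embedder maps each 3-d
# point to 3 + 3*2*L features (identity plus sin/cos at L frequencies).
def _example_positional_encoding():
    embed_fn, out_dim = get_embedder(multires=4)
    pts = torch.rand(5, 3)            # five 3-d points
    enc = embed_fn(pts)               # shape (5, out_dim) == (5, 27)
    assert enc.shape == (5, out_dim) and out_dim == 3 + 3 * 2 * 4
    return enc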
class APOPMeter:
def __init__(self) -> None:
self.tp = 0
self.fp = 0
self.tn = 0
self.fn = 0
def update(self, pred, gt):
"""
Input:
pred, gt: Tensor()
"""
assert pred.shape == gt.shape
self.tp += torch.logical_and(pred == 1, gt == 1).sum().item()
self.fp += torch.logical_and(pred == 1, gt == 0).sum().item()
self.tn += torch.logical_and(pred == 0, gt == 0).sum().item()
        self.fn += torch.logical_and(pred == 0, gt == 1).sum().item()
def update_cm(self, tp, fp, tn, fn):
self.tp += tp
self.fp += fp
self.tn += tn
        self.fn += fn
def inverse_sigmoid(x, eps=1e-5):
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2)
def get_raw_dict(args):
"""
    Return the dict contained in args.
e.g:
>>> with open(path, 'w') as f:
json.dump(get_raw_dict(args), f, indent=2)
"""
if isinstance(args, argparse.Namespace):
return vars(args)
elif isinstance(args, dict):
return args
elif isinstance(args, SLConfig):
return args._cfg_dict
else:
raise NotImplementedError("Unknown type {}".format(type(args)))
def stat_tensors(tensor):
assert tensor.dim() == 1
tensor_sm = tensor.softmax(0)
entropy = (tensor_sm * torch.log(tensor_sm + 1e-9)).sum()
return {
"max": tensor.max(),
"min": tensor.min(),
"mean": tensor.mean(),
"var": tensor.var(),
"std": tensor.var() ** 0.5,
"entropy": entropy,
}
class NiceRepr:
"""Inherit from this class and define ``__nice__`` to "nicely" print your
objects.
Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function
Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
    If the inheriting class has a ``__len__`` method, then the default
``__nice__`` method will return its length.
Example:
>>> class Foo(NiceRepr):
... def __nice__(self):
... return 'info'
>>> foo = Foo()
>>> assert str(foo) == '<Foo(info)>'
>>> assert repr(foo).startswith('<Foo(info) at ')
Example:
>>> class Bar(NiceRepr):
... pass
>>> bar = Bar()
>>> import pytest
>>> with pytest.warns(None) as record:
>>> assert 'object at' in str(bar)
>>> assert 'object at' in repr(bar)
Example:
>>> class Baz(NiceRepr):
... def __len__(self):
... return 5
>>> baz = Baz()
>>> assert str(baz) == '<Baz(5)>'
"""
def __nice__(self):
"""str: a "nice" summary string describing this module"""
if hasattr(self, "__len__"):
# It is a common pattern for objects to use __len__ in __nice__
# As a convenience we define a default __nice__ for these objects
return str(len(self))
else:
# In all other cases force the subclass to overload __nice__
raise NotImplementedError(f"Define the __nice__ method for {self.__class__!r}")
def __repr__(self):
"""str: the string of the module"""
try:
nice = self.__nice__()
classname = self.__class__.__name__
return f"<{classname}({nice}) at {hex(id(self))}>"
except NotImplementedError as ex:
warnings.warn(str(ex), category=RuntimeWarning)
return object.__repr__(self)
def __str__(self):
"""str: the string of the module"""
try:
classname = self.__class__.__name__
nice = self.__nice__()
return f"<{classname}({nice})>"
except NotImplementedError as ex:
warnings.warn(str(ex), category=RuntimeWarning)
return object.__repr__(self)
def ensure_rng(rng=None):
"""Coerces input into a random number generator.
If the input is None, then a global random state is returned.
If the input is a numeric value, then that is used as a seed to construct a
random state. Otherwise the input is returned as-is.
Adapted from [1]_.
Args:
rng (int | numpy.random.RandomState | None):
if None, then defaults to the global rng. Otherwise this can be an
integer or a RandomState class
Returns:
(numpy.random.RandomState) : rng -
a numpy random number generator
References:
.. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501
"""
if rng is None:
rng = np.random.mtrand._rand
elif isinstance(rng, int):
rng = np.random.RandomState(rng)
else:
rng = rng
return rng
def random_boxes(num=1, scale=1, rng=None):
"""Simple version of ``kwimage.Boxes.random``
Returns:
Tensor: shape (n, 4) in x1, y1, x2, y2 format.
References:
https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
Example:
>>> num = 3
>>> scale = 512
>>> rng = 0
>>> boxes = random_boxes(num, scale, rng)
>>> print(boxes)
tensor([[280.9925, 278.9802, 308.6148, 366.1769],
[216.9113, 330.6978, 224.0446, 456.5878],
[405.3632, 196.3221, 493.3953, 270.7942]])
"""
rng = ensure_rng(rng)
tlbr = rng.rand(num, 4).astype(np.float32)
tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
tlbr[:, 0] = tl_x * scale
tlbr[:, 1] = tl_y * scale
tlbr[:, 2] = br_x * scale
tlbr[:, 3] = br_y * scale
boxes = torch.from_numpy(tlbr)
return boxes
class ModelEma(torch.nn.Module):
def __init__(self, model, decay=0.9997, device=None):
super(ModelEma, self).__init__()
# make a copy of the model for accumulating moving average of weights
self.module = deepcopy(model)
self.module.eval()
# import ipdb; ipdb.set_trace()
self.decay = decay
self.device = device # perform ema on different device from model if set
if self.device is not None:
self.module.to(device=device)
def _update(self, model, update_fn):
with torch.no_grad():
for ema_v, model_v in zip(
self.module.state_dict().values(), model.state_dict().values()
):
if self.device is not None:
model_v = model_v.to(device=self.device)
ema_v.copy_(update_fn(ema_v, model_v))
def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m)
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
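# --------------------------------------------------------------------------
# Usage sketch (illustrative): keep an EMA copy of a model and update it after
# each optimizer step; evaluation would then use `ema.module`.
def _example_model_ema():
    model = torch.nn.Linear(4, 2)
    ema = ModelEma(model, decay=0.999)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(2):                       # stands in for a training loop
        loss = model(torch.randn(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)                    # EMA weights lag the live weights
    return ema.module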
class BestMetricSingle:
def __init__(self, init_res=0.0, better="large") -> None:
self.init_res = init_res
self.best_res = init_res
self.best_ep = -1
self.better = better
assert better in ["large", "small"]
def isbetter(self, new_res, old_res):
if self.better == "large":
return new_res > old_res
if self.better == "small":
return new_res < old_res
def update(self, new_res, ep):
if self.isbetter(new_res, self.best_res):
self.best_res = new_res
self.best_ep = ep
return True
return False
def __str__(self) -> str:
return "best_res: {}\t best_ep: {}".format(self.best_res, self.best_ep)
def __repr__(self) -> str:
return self.__str__()
def summary(self) -> dict:
return {
"best_res": self.best_res,
"best_ep": self.best_ep,
}
class BestMetricHolder:
def __init__(self, init_res=0.0, better="large", use_ema=False) -> None:
self.best_all = BestMetricSingle(init_res, better)
self.use_ema = use_ema
if use_ema:
self.best_ema = BestMetricSingle(init_res, better)
self.best_regular = BestMetricSingle(init_res, better)
def update(self, new_res, epoch, is_ema=False):
"""
        Return whether the new result is the best so far.
"""
if not self.use_ema:
return self.best_all.update(new_res, epoch)
else:
if is_ema:
self.best_ema.update(new_res, epoch)
return self.best_all.update(new_res, epoch)
else:
self.best_regular.update(new_res, epoch)
return self.best_all.update(new_res, epoch)
def summary(self):
if not self.use_ema:
return self.best_all.summary()
res = {}
res.update({f"all_{k}": v for k, v in self.best_all.summary().items()})
res.update({f"regular_{k}": v for k, v in self.best_regular.summary().items()})
res.update({f"ema_{k}": v for k, v in self.best_ema.summary().items()})
return res
def __repr__(self) -> str:
return json.dumps(self.summary(), indent=2)
def __str__(self) -> str:
return self.__repr__()
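# --------------------------------------------------------------------------
# Usage sketch (illustrative): track the best validation metric across epochs,
# separately for regular and EMA weights when use_ema=True.
def _example_best_metric_holder():
    holder = BestMetricHolder(better="large", use_ema=True)
    for epoch, (map_regular, map_ema) in enumerate([(0.30, 0.28), (0.35, 0.36)]):
        is_best = holder.update(map_regular, epoch, is_ema=False)
        holder.update(map_ema, epoch, is_ema=True)
        print(f"epoch {epoch}: new overall best = {is_best}")
    return holder.summary()  # keys like all_best_res, regular_best_res, ema_best_res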
def targets_to(targets: List[Dict[str, Any]], device):
"""Moves the target dicts to the given device."""
excluded_keys = [
"questionId",
"tokens_positive",
"strings_positive",
"tokens",
"dataset_name",
"sentence_id",
"original_img_id",
"nb_eval",
"task_id",
"original_id",
"token_span",
"caption",
"dataset_type",
]
return [
{k: v.to(device) if k not in excluded_keys else v for k, v in t.items()} for t in targets
]
def get_phrases_from_posmap(
posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer
):
assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
if posmap.dim() == 1:
non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
return tokenizer.decode(token_ids)
else:
raise NotImplementedError("posmap must be 1-dim")
# -*- coding: utf-8 -*-
"""
@File : visualizer.py
@Time : 2022/04/05 11:39:33
@Author : Shilong Liu
@Contact : slongliu86@gmail.com
"""
import datetime
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import transforms
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
from pycocotools import mask as maskUtils
def renorm(
img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
) -> torch.FloatTensor:
# img: tensor(3,H,W) or tensor(B,3,H,W)
# return: same as img
assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
if img.dim() == 3:
        assert img.size(0) == 3, 'img.size(0) should be 3 but "%d". (%s)' % (
img.size(0),
str(img.size()),
)
img_perm = img.permute(1, 2, 0)
mean = torch.Tensor(mean)
std = torch.Tensor(std)
img_res = img_perm * std + mean
return img_res.permute(2, 0, 1)
else: # img.dim() == 4
        assert img.size(1) == 3, 'img.size(1) should be 3 but "%d". (%s)' % (
img.size(1),
str(img.size()),
)
img_perm = img.permute(0, 2, 3, 1)
mean = torch.Tensor(mean)
std = torch.Tensor(std)
img_res = img_perm * std + mean
return img_res.permute(0, 3, 1, 2)
class ColorMap:
def __init__(self, basergb=[255, 255, 0]):
self.basergb = np.array(basergb)
def __call__(self, attnmap):
# attnmap: h, w. np.uint8.
# return: h, w, 4. np.uint8.
assert attnmap.dtype == np.uint8
h, w = attnmap.shape
res = self.basergb.copy()
res = res[None][None].repeat(h, 0).repeat(w, 1) # h, w, 3
attn1 = attnmap.copy()[..., None] # h, w, 1
res = np.concatenate((res, attn1), axis=-1).astype(np.uint8)
return res
def rainbow_text(x, y, ls, lc, **kw):
"""
Take a list of strings ``ls`` and colors ``lc`` and place them next to each
other, with text ls[i] being shown in color lc[i].
This example shows how to do both vertical and horizontal text, and will
pass all keyword arguments to plt.text, so you can set the font size,
family, etc.
"""
t = plt.gca().transData
fig = plt.gcf()
plt.show()
# horizontal version
for s, c in zip(ls, lc):
text = plt.text(x, y, " " + s + " ", color=c, transform=t, **kw)
text.draw(fig.canvas.get_renderer())
ex = text.get_window_extent()
t = transforms.offset_copy(text._transform, x=ex.width, units="dots")
# #vertical version
# for s,c in zip(ls,lc):
# text = plt.text(x,y," "+s+" ",color=c, transform=t,
# rotation=90,va='bottom',ha='center',**kw)
# text.draw(fig.canvas.get_renderer())
# ex = text.get_window_extent()
# t = transforms.offset_copy(text._transform, y=ex.height, units='dots')
class COCOVisualizer:
def __init__(self, coco=None, tokenlizer=None) -> None:
self.coco = coco
def visualize(self, img, tgt, caption=None, dpi=180, savedir="vis"):
"""
img: tensor(3, H, W)
tgt: make sure they are all on cpu.
must have items: 'image_id', 'boxes', 'size'
"""
plt.figure(dpi=dpi)
plt.rcParams["font.size"] = "5"
ax = plt.gca()
img = renorm(img).permute(1, 2, 0)
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
ax.imshow(img)
self.addtgt(tgt)
if tgt is None:
image_id = 0
elif "image_id" not in tgt:
image_id = 0
else:
image_id = tgt["image_id"]
if caption is None:
savename = "{}/{}-{}.png".format(
savedir, int(image_id), str(datetime.datetime.now()).replace(" ", "-")
)
else:
savename = "{}/{}-{}-{}.png".format(
savedir, caption, int(image_id), str(datetime.datetime.now()).replace(" ", "-")
)
print("savename: {}".format(savename))
os.makedirs(os.path.dirname(savename), exist_ok=True)
plt.savefig(savename)
plt.close()
def addtgt(self, tgt):
""" """
if tgt is None or not "boxes" in tgt:
ax = plt.gca()
if "caption" in tgt:
ax.set_title(tgt["caption"], wrap=True)
ax.set_axis_off()
return
ax = plt.gca()
H, W = tgt["size"]
numbox = tgt["boxes"].shape[0]
color = []
polygons = []
boxes = []
for box in tgt["boxes"].cpu():
unnormbbox = box * torch.Tensor([W, H, W, H])
unnormbbox[:2] -= unnormbbox[2:] / 2
[bbox_x, bbox_y, bbox_w, bbox_h] = unnormbbox.tolist()
boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
poly = [
[bbox_x, bbox_y],
[bbox_x, bbox_y + bbox_h],
[bbox_x + bbox_w, bbox_y + bbox_h],
[bbox_x + bbox_w, bbox_y],
]
np_poly = np.array(poly).reshape((4, 2))
polygons.append(Polygon(np_poly))
c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
color.append(c)
p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1)
ax.add_collection(p)
p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
ax.add_collection(p)
if "strings_positive" in tgt and len(tgt["strings_positive"]) > 0:
assert (
len(tgt["strings_positive"]) == numbox
), f"{len(tgt['strings_positive'])} = {numbox}, "
for idx, strlist in enumerate(tgt["strings_positive"]):
cate_id = int(tgt["labels"][idx])
_string = str(cate_id) + ":" + " ".join(strlist)
bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx]
# ax.text(bbox_x, bbox_y, _string, color='black', bbox={'facecolor': 'yellow', 'alpha': 1.0, 'pad': 1})
ax.text(
bbox_x,
bbox_y,
_string,
color="black",
bbox={"facecolor": color[idx], "alpha": 0.6, "pad": 1},
)
if "box_label" in tgt:
assert len(tgt["box_label"]) == numbox, f"{len(tgt['box_label'])} = {numbox}, "
for idx, bl in enumerate(tgt["box_label"]):
_string = str(bl)
bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx]
# ax.text(bbox_x, bbox_y, _string, color='black', bbox={'facecolor': 'yellow', 'alpha': 1.0, 'pad': 1})
ax.text(
bbox_x,
bbox_y,
_string,
color="black",
bbox={"facecolor": color[idx], "alpha": 0.6, "pad": 1},
)
if "caption" in tgt:
ax.set_title(tgt["caption"], wrap=True)
# plt.figure()
# rainbow_text(0.0,0.0,"all unicorns poop rainbows ! ! !".split(),
# ['red', 'orange', 'brown', 'green', 'blue', 'purple', 'black'])
if "attn" in tgt:
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
if isinstance(tgt["attn"], tuple):
tgt["attn"] = [tgt["attn"]]
for item in tgt["attn"]:
attn_map, basergb = item
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3)
attn_map = (attn_map * 255).astype(np.uint8)
cm = ColorMap(basergb)
heatmap = cm(attn_map)
ax.imshow(heatmap)
ax.set_axis_off()
def showAnns(self, anns, draw_bbox=False):
"""
Display the specified annotations.
:param anns (array of object): annotations to display
:return: None
"""
if len(anns) == 0:
return 0
if "segmentation" in anns[0] or "keypoints" in anns[0]:
datasetType = "instances"
elif "caption" in anns[0]:
datasetType = "captions"
else:
raise Exception("datasetType not supported")
if datasetType == "instances":
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in anns:
c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
if "segmentation" in ann:
if type(ann["segmentation"]) == list:
# polygon
for seg in ann["segmentation"]:
poly = np.array(seg).reshape((int(len(seg) / 2), 2))
polygons.append(Polygon(poly))
color.append(c)
else:
# mask
t = self.imgs[ann["image_id"]]
if type(ann["segmentation"]["counts"]) == list:
rle = maskUtils.frPyObjects(
[ann["segmentation"]], t["height"], t["width"]
)
else:
rle = [ann["segmentation"]]
m = maskUtils.decode(rle)
img = np.ones((m.shape[0], m.shape[1], 3))
if ann["iscrowd"] == 1:
color_mask = np.array([2.0, 166.0, 101.0]) / 255
if ann["iscrowd"] == 0:
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:, :, i] = color_mask[i]
ax.imshow(np.dstack((img, m * 0.5)))
if "keypoints" in ann and type(ann["keypoints"]) == list:
# turn skeleton into zero-based index
                    sks = np.array(self.coco.loadCats(ann["category_id"])[0]["skeleton"]) - 1
kp = np.array(ann["keypoints"])
x = kp[0::3]
y = kp[1::3]
v = kp[2::3]
for sk in sks:
if np.all(v[sk] > 0):
plt.plot(x[sk], y[sk], linewidth=3, color=c)
plt.plot(
x[v > 0],
y[v > 0],
"o",
markersize=8,
markerfacecolor=c,
markeredgecolor="k",
markeredgewidth=2,
)
plt.plot(
x[v > 1],
y[v > 1],
"o",
markersize=8,
markerfacecolor=c,
markeredgecolor=c,
markeredgewidth=2,
)
if draw_bbox:
[bbox_x, bbox_y, bbox_w, bbox_h] = ann["bbox"]
poly = [
[bbox_x, bbox_y],
[bbox_x, bbox_y + bbox_h],
[bbox_x + bbox_w, bbox_y + bbox_h],
[bbox_x + bbox_w, bbox_y],
]
np_poly = np.array(poly).reshape((4, 2))
polygons.append(Polygon(np_poly))
color.append(c)
# p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
# ax.add_collection(p)
p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
ax.add_collection(p)
elif datasetType == "captions":
for ann in anns:
print(ann["caption"])
import os
import random
from typing import List
import torch
def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
"""construct a map such that positive_map[i,j] = True iff box i is associated to token j
Input:
- tokenized:
- input_ids: Tensor[1, ntokens]
- attention_mask: Tensor[1, ntokens]
- token_span: list with length num_boxes.
- each item: [start_idx, end_idx]
"""
positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float)
for j, tok_list in enumerate(token_span):
for (beg, end) in tok_list:
beg_pos = tokenized.char_to_token(beg)
end_pos = tokenized.char_to_token(end - 1)
if beg_pos is None:
try:
beg_pos = tokenized.char_to_token(beg + 1)
if beg_pos is None:
beg_pos = tokenized.char_to_token(beg + 2)
                except Exception:
beg_pos = None
if end_pos is None:
try:
end_pos = tokenized.char_to_token(end - 2)
if end_pos is None:
end_pos = tokenized.char_to_token(end - 3)
                except Exception:
end_pos = None
if beg_pos is None or end_pos is None:
continue
assert beg_pos is not None and end_pos is not None
if os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE":
positive_map[j, beg_pos] = 1
break
else:
positive_map[j, beg_pos : end_pos + 1].fill_(1)
return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
def build_captions_and_token_span(cat_list, force_lowercase):
"""
Return:
captions: str
cat2tokenspan: dict
{
'dog': [[0, 2]],
...
}
"""
cat2tokenspan = {}
captions = ""
for catname in cat_list:
class_name = catname
if force_lowercase:
class_name = class_name.lower()
if "/" in class_name:
class_name_list: List = class_name.strip().split("/")
class_name_list.append(class_name)
class_name: str = random.choice(class_name_list)
tokens_positive_i = []
subnamelist = [i.strip() for i in class_name.strip().split(" ")]
for subname in subnamelist:
if len(subname) == 0:
continue
if len(captions) > 0:
captions = captions + " "
            start_idx = len(captions)
            end_idx = start_idx + len(subname)
            tokens_positive_i.append([start_idx, end_idx])
captions = captions + subname
if len(tokens_positive_i) > 0:
captions = captions + " ."
cat2tokenspan[class_name] = tokens_positive_i
return captions, cat2tokenspan
def build_id2posspan_and_caption(category_dict: dict):
"""Build id2pos_span and caption from category_dict
Args:
category_dict (dict): category_dict
"""
cat_list = [item["name"].lower() for item in category_dict]
id2catname = {item["id"]: item["name"].lower() for item in category_dict}
caption, cat2posspan = build_captions_and_token_span(cat_list, force_lowercase=True)
id2posspan = {catid: cat2posspan[catname] for catid, catname in id2catname.items()}
return id2posspan, caption
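# --------------------------------------------------------------------------
# Usage sketch (illustrative; assumes the `transformers` library and network
# access for the "bert-base-uncased" tokenizer): build a caption plus
# character-level spans from category names, then turn the spans into a
# (num_cats, 256) positive map over token positions.
def _example_positive_map():
    from transformers import AutoTokenizer

    cat_list = ["person", "dog", "traffic light"]
    captions, cat2tokenspan = build_captions_and_token_span(cat_list, force_lowercase=True)
    # captions == "person . dog . traffic light ."
    tokenized = AutoTokenizer.from_pretrained("bert-base-uncased")(captions, return_tensors="pt")
    token_spans = [cat2tokenspan[cat] for cat in cat_list]
    positive_map = create_positive_map_from_span(tokenized, token_spans)
    return captions, positive_map  # positive_map.shape == (3, 256)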
torch
torchvision
transformers
addict
yapf
timm
numpy
opencv-python
supervision
pycocotools
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Modified from
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/setup.py
# https://github.com/facebookresearch/detectron2/blob/main/setup.py
# https://github.com/open-mmlab/mmdetection/blob/master/setup.py
# https://github.com/Oneflow-Inc/libai/blob/main/setup.py
# ------------------------------------------------------------------------------------------------
import glob
import os
import subprocess
import torch
from setuptools import find_packages, setup
# from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
from torch.utils.cpp_extension import ROCM_HOME, CppExtension
# Use the CUDAExtension from fastpt
from fastpt import CUDAExtension
# groundingdino version info
version = "0.1.0"
package_name = "groundingdino"
cwd = os.path.dirname(os.path.abspath(__file__))
sha = "Unknown"
try:
sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
except Exception:
pass
def write_version_file():
version_path = os.path.join(cwd, "groundingdino", "version.py")
with open(version_path, "w") as f:
f.write(f"__version__ = '{version}'\n")
# f.write(f"git_version = {repr(sha)}\n")
requirements = ["torch", "torchvision"]
torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "groundingdino", "models", "GroundingDINO", "csrc")
main_source = os.path.join(extensions_dir, "vision.cpp")
sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
os.path.join(extensions_dir, "*.cu")
)
sources = [main_source] + sources
# We need these variables to build with CUDA when we create the Docker image
# It solves https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/53
# and https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/84 when running
# inside a Docker container.
am_i_docker = os.environ.get('AM_I_DOCKER', '').casefold() in ['true', '1', 't']
use_cuda = os.environ.get('BUILD_WITH_CUDA', '').casefold() in ['true', '1', 't']
extension = CppExtension
extra_compile_args = {"cxx": []}
define_macros = []
# CUDA_HOME -> ROCM_HOME
if (torch.cuda.is_available() and ROCM_HOME is not None) or \
(am_i_docker and use_cuda):
print("Compiling with CUDA")
extension = CUDAExtension
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
else:
print("Compiling without CUDA")
define_macros += [("WITH_HIP", None)]
extra_compile_args["nvcc"] = []
return None
sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir]
ext_modules = [
extension(
"groundingdino._C",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
return ext_modules
def parse_requirements(fname="requirements.txt", with_version=True):
"""Parse the package dependencies listed in a requirements file but strips
specific versioning information.
Args:
fname (str): path to requirements file
with_version (bool, default=False): if True include version specs
Returns:
List[str]: list of requirements items
CommandLine:
python -c "import setup; print(setup.parse_requirements())"
"""
import re
import sys
from os.path import exists
require_fpath = fname
def parse_line(line):
"""Parse information from a line in a requirements text file."""
if line.startswith("-r "):
# Allow specifying requirements in other files
target = line.split(" ")[1]
for info in parse_require_file(target):
yield info
else:
info = {"line": line}
if line.startswith("-e "):
info["package"] = line.split("#egg=")[1]
elif "@git+" in line:
info["package"] = line
else:
# Remove versioning from the package
pat = "(" + "|".join([">=", "==", ">"]) + ")"
parts = re.split(pat, line, maxsplit=1)
parts = [p.strip() for p in parts]
info["package"] = parts[0]
if len(parts) > 1:
op, rest = parts[1:]
if ";" in rest:
# Handle platform specific dependencies
# http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
version, platform_deps = map(str.strip, rest.split(";"))
info["platform_deps"] = platform_deps
else:
version = rest # NOQA
info["version"] = (op, version)
yield info
def parse_require_file(fpath):
with open(fpath, "r") as f:
for line in f.readlines():
line = line.strip()
if line and not line.startswith("#"):
for info in parse_line(line):
yield info
def gen_packages_items():
if exists(require_fpath):
for info in parse_require_file(require_fpath):
parts = [info["package"]]
if with_version and "version" in info:
parts.extend(info["version"])
if not sys.version.startswith("3.4"):
# apparently package_deps are broken in 3.4
platform_deps = info.get("platform_deps")
if platform_deps is not None:
parts.append(";" + platform_deps)
item = "".join(parts)
yield item
packages = list(gen_packages_items())
return packages
if __name__ == "__main__":
print(f"Building wheel {package_name}-{version}")
with open("LICENSE", "r", encoding="utf-8") as f:
license = f.read()
write_version_file()
setup(
name="groundingdino",
version="0.1.0",
author="International Digital Economy Academy, Shilong Liu",
url="https://github.com/IDEA-Research/GroundingDINO",
description="open-set object detector",
license=license,
install_requires=parse_requirements("requirements.txt"),
packages=find_packages(
exclude=(
"configs",
"tests",
)
),
ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
# Installation Instructions
## MAM Setup
Our local setup uses Python 3.9, PyTorch 1.13.1 (CUDA 11.7 build), torchvision 0.14.1, and diffusers 0.17.0. If you run into version mismatch issues during installation, you can pin these versions in requirements.txt to match our setup.
### Create a conda environment
```bash
conda create --name mam python=3.9 -y
conda activate mam
```
### Install packages and other dependencies.
```bash
git clone https://github.com/SHI-Labs/Matting-Anything
cd Matting-Anything
# Install all dependencies
pip install -r requirements.txt
# Install segment-anything
python -m pip install -e segment-anything
# Install Grounding DINO
export BUILD_WITH_CUDA=True
export CUDA_HOME=/path/to/cuda/
python -m pip install -e GroundingDINO
# Install diffusers
pip install --upgrade diffusers[torch]
```
More details can be found in [segment anything](https://github.com/facebookresearch/segment-anything#installation) and [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO#install) if you run into any installation issues.
### Download the pre-trained weights.
```bash
mkdir checkpoints
cd checkpoints
# Download GroundingDINO model
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
# Download MAM models manually from Google Drive:
# https://drive.google.com/drive/folders/1Bor2jRE0U-U6PIYaCm6SZY7qu_c1GYfq?usp=sharing
```
## Gradio Setup
You can set up the gradio demo locally by simply running
```bash
python gradio_app.py
```
to launch and play with the demo based on the SAM ViT-B model.
We support 3 prompt types in the local Gradio app for MAM:
1. **scribble_point**: Click a point on the target instance for matting.
2. **scribble_box**: Click two points, the top-left and the bottom-right, to define a bounding box around the target instance.
3. **text**: Send a text prompt to identify the target instance in the `Text Prompt` box.
We support 2 background types to support image composition with the alpha matte output:
1. **real_world_sample**: Randomly select a real-world image from `assets/backgrounds` for composition.
2. **generated_by_text**: Send a background text prompt to create a background image with the stable diffusion model in the `Background Prompt` box.
You can also play with the demo online at [HuggingFace](https://huggingface.co/spaces/shi-labs/Matting-Anything).
MIT License
Copyright (c) 2023 SHI Labs
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# Matting-Anything
Matting-Anything is an interactive natural image matting model.
## Paper
`Matting Anything`
- https://arxiv.org/abs/2306.05399
## Model Architecture
<!-- A one-sentence description of the model architecture goes here -->
<div align=center>
<img src="./doc/Architecture.png"/>
<div >Matting-Anything</div>
</div>
## Algorithm Overview
The Matting Anything Model (MAM) is a general framework that handles all kinds of image matting scenarios with a single model. \
MAM builds on the Segment Anything Model (SAM): it takes the mask-augmented feature maps produced by SAM and iteratively refines them into the final alpha matte. Since the SAM parameters are frozen, the whole framework only trains the 2.7 million parameters of the MAM module. The framework integrates box, point, and text prompts as interaction modes, and additionally uses a Stable Diffusion model to generate virtual background images, which makes matting more playful and interactive.
## Environment Setup
```
mv matting-anything_pytorch matting-anything  # drop the framework-name suffix
# Adjust the docker -v paths, docker_name, and imageID below to your environment
# If pip downloads are slow, try a different package index mirror
```
### Docker (Method 1)
<!-- Provide the address of the docker image pulled from [SourceFind](https://www.sourcefind.cn/#/service-details) and the usage steps here -->
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.2-py3.10 # the imageID of this image is 2f1f619d0182
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=16G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --network=host --name docker_name imageID bash
cd /your_code_path/matting-anything
pip install -r requirements.txt
# Install segment-anything
python -m pip install -e segment-anything
# Install GroundingDINO
python -m pip install -e GroundingDINO
# Install diffusers
pip install --upgrade diffusers[torch]
```
### Dockerfile (Method 2)
<!-- Provide instructions for using the Dockerfile here -->
```
cd /your_code_path/matting-anything/docker
docker build --no-cache -t codestral:latest .
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=16G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --network=host --name docker_name imageID bash
cd /your_code_path/matting-anything
pip install -r requirements.txt
# Install segment-anything
python -m pip install -e segment-anything
# Install GroundingDINO
python -m pip install -e GroundingDINO
# Install diffusers
pip install --upgrade diffusers[torch]
```
### Anaconda (Method 3)
<!-- Provide detailed steps for local setup and compilation here, for example: -->
The special deep-learning libraries required by this project for DCU GPUs can be downloaded and installed from the [HPC Cube](https://developer.hpccube.com/tool/) developer community.
```
DTK driver: dtk24.04.2
python: python3.10
pytorch: 2.1.0
```
`Tips: the versions of the DTK driver, python, pytorch, and other DCU-related tools above must strictly match one another`
Install the remaining non-deep-learning libraries according to requirements.txt:
```
pip install -r requirements.txt
# Install segment-anything
python -m pip install -e segment-anything
# Install GroundingDINO
python -m pip install -e GroundingDINO
# Install diffusers
pip install --upgrade diffusers[torch]
```
## Dataset
## Training
## Inference
<!-- Download [sam_vit_h_4b8939.pth](https://github.com/facebookresearch/segment-anything?tab=readme-ov-file#model-checkpoints)
or get it quickly from [SCNet](http://113.200.138.88:18080/aimodels/findsource-dependency/sam_vit_h_4b8939); -->
Download [GroundingDINO-T](https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth), or get it quickly from [SCNet]();
download [MAM](https://drive.google.com/drive/folders/1Bor2jRE0U-U6PIYaCm6SZY7qu_c1GYfq?usp=sharing), or get it quickly from [SCNet]().
Place them under ./checkpoints.
Web UI inference with visualization:
```
python gradio_app.py --listen
```
<div align=center>
<img src="./doc/webui_result.png" width=600/>
<div >Web UI</div>
</div>
1. Upload an image;\
2. Select the target or region by clicking a point, drawing a box, or entering a text prompt;\
3. Choose the replacement background, either one of the provided real images or an image generated from a background text prompt;\
4. Run.\
Note: please do not submit operations too frequently; adjust the parameters as needed.
## Results
<!-- Put example test images of the algorithm (input and output) here -->
<div align=center>
<img src="./doc/demo.jpg" width=600/>
<div >Input</div>
</div>
<div align=center>
<img src="./doc/matte.png" width=600/>
<div >Matting result</div>
</div>
<div align=center>
<img src="./doc/result.png" width=600/>
<div >Background replacement</div>
</div>
### Accuracy
None.
<!-- | Accelerator | lpips | clip sim |
| :-----| :----- | :---- |
| K100_AI | 0.115 | 0.977 | -->
<!-- | Cell | Cell | Cell | -->
## Application Scenarios
### Algorithm Category
<!-- For categories beyond the ones above, you can also refer to the category names on https://huggingface.co/ \ -->
`AIGC`
### Key Application Industries
<!-- Filling in application industries requires substantial research so that users get professional, comprehensive recommendations; except for special algorithms, usually list >= 3. -->
`Retail, manufacturing, e-commerce, healthcare, education`
## Source Repository and Issue Feedback
<!-- - Put this project's GitLab URL here -->
- https://developer.sourcefind.cn/codes/modelzoo/matting-anything_pytorch
## References
- https://github.com/SHI-Labs/Matting-Anything
# Matting Anything
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=XY2Q0HATGOk)
[![HuggingFace Space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/shi-labs/Matting-Anything)
[![Framework: PyTorch](https://img.shields.io/badge/Framework-PyTorch-orange.svg)](https://pytorch.org/)
[![License](https://img.shields.io/badge/License-MIT-red.svg)](https://opensource.org/licenses/MIT)
[Jiachen Li](https://chrisjuniorli.github.io/),
[Jitesh Jain](https://praeclarumjj3.github.io/),
[Humphrey Shi](https://www.humphreyshi.com/home)
[[`Project page`](https://chrisjuniorli.github.io/project/Matting-Anything/)]
[[`ArXiv`](https://arxiv.org/abs/2306.05399)]
[[`Pdf`](https://arxiv.org/pdf/2306.05399.pdf)]
[[`Video`](https://www.youtube.com/watch?v=XY2Q0HATGOk)]
[[`Demo`](https://huggingface.co/spaces/shi-labs/Matting-Anything)]
![](./assets/teaser_arxiv_v2.png)
## Updates
- **`2023/07/17`**: Added MAM checkpoints based on SAM ViT-L and SAM ViT-H.
- **`2023/06/28`**: [**Getting Started**](https://github.com/SHI-Labs/Matting-Anything/blob/main/GETTING_STARTED.md) is updated with training and evaluation instructions.
- **`2023/06/09`**: [**HuggingFace Demo**](https://huggingface.co/spaces/shi-labs/Matting-Anything) is released.
- **`2023/06/08`**: [**Arxiv Preprint**](https://arxiv.org/abs/2306.05399) is released.
- **`2023/06/06`**: [**Project Page**](https://chrisjuniorli.github.io/project/Matting-Anything) and [**Demo Video**](https://www.youtube.com/watch?v=XY2Q0HATGOk) are released.
## Contents
- [Matting-Anything](#matting-anything)
- [Installation](#installation)
- [Getting Started](#getting-started)
- [Third-Party Projects](#third-party-projects)
## Matting Anything
### Abstract
In this paper, we propose the Matting Anything Model (MAM), an efficient and versatile framework for estimating the alpha matte of any instance in an image with flexible and interactive visual or linguistic user prompt guidance. MAM offers several significant advantages over previous specialized image matting networks: (i) MAM is capable of dealing with various types of image matting, including semantic, instance, and referring image matting with only a single model; (ii) MAM leverages the feature maps from the Segment Anything Model (SAM) and adopts a lightweight Mask-to-Matte (M2M) module to predict the alpha matte through iterative refinement, which has only 2.7 million trainable parameters. (iii) By incorporating SAM, MAM simplifies the user intervention required for the interactive use of image matting from the trimap to the box, point, or text prompt. We evaluate the performance of MAM on various image matting benchmarks, and the experimental results demonstrate that MAM achieves comparable performance to the state-of-the-art specialized image matting models under different metrics on each benchmark. Overall, MAM shows superior generalization ability and can effectively handle various image matting tasks with fewer parameters, making it a practical solution for unified image matting.
### Architecture
<div align="center">
<img src="assets/arxiv_fix.png" width="100%" height="100%"/>
</div><br/>
The MAM architecture consists of a pre-trained SAM and an M2M module. Given an
input image I, SAM generates the mask prediction for the target instance based on the box or point user prompt. The M2M module takes
the concatenated inputs, including the image, mask, and feature maps, and produces multi-scale predictions αos8, αos4, and αos1. The
iterative refinement process, detailed in Section 3, progressively improves the precision of the final meticulous alpha matte α, incorporating
information from the multi-scale outputs.
### Visualization
<div align="center">
<img src="assets/teaser.gif" width="100%" height="100%"/>
</div>
<div align="center">
<img src="assets/mam_vis_v2.png" width="100%" height="100%"/>
</div><br/>
We provide visualizations of the alpha matte
predictions from SAM and MAM. Notably, we emphasize
the differences in the red boxes. The visualizations demonstrate that MAM achieves improved predictions in the transition areas even without the trimap guidance. Additionally,
MAM effectively addresses some of the holes present in the mask predictions generated by SAM. These visual comparisons highlight the superior performance of MAM in refining and enhancing the quality of alpha matte predictions.
## Installation
Please refer to [Installation Instructions](INSTALL.md) for complete installation instructions for MAM.
## Getting Started
Please refer to [Getting Started](GETTING_STARTED.md) for dataset preparation, training, and inference details of MAM.
## Third-Party Projects
* [Matting-Anything-Colab](https://github.com/camenduru/Matting-Anything-colab) ([@camenduru](https://twitter.com/camenduru))
* [Matting-Anything-Video](https://huggingface.co/spaces/fffiloni/Video-Matting-Anything) ([@fffiloni](https://twitter.com/fffiloni))
## Citation
```bibtex
@article{li2023matting,
title={Matting Anything},
author={Jiachen Li and Jitesh Jain and Humphrey Shi},
journal={arXiv: 2306.05399},
year={2023}
}
```
## Acknowledgement
We thank the authors of [SAM](https://github.com/facebookresearch/segment-anything), [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything), [MGMatting](https://github.com/yucornetto/MGMatting), and [InstMatt](https://github.com/nowsyn/InstMatt/tree/main) for releasing the codebases.