Commit eb807a19 authored by mashun1's avatar mashun1
Browse files

add icon

parents
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
FastSAM model interface.
Usage - Predict:
from ultralytics import FastSAM
model = FastSAM('last.pt')
results = model.predict('ultralytics/assets/bus.jpg')
"""
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
from .predict import FastSAMPredictor
class FastSAM(YOLO):
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs) # prefer kwargs
overrides['mode'] = kwargs.get('mode', 'predict')
assert overrides['mode'] in ['track', 'predict']
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
self.predictor = FastSAMPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model, verbose=False)
try:
return self.predictor(source, stream=stream)
except Exception as e:
return None
def train(self, **kwargs):
"""Function trains models but raises an error as FastSAM models do not support training."""
raise NotImplementedError("Currently, the training codes are on the way.")
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='segment', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = FastSAM(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
import torch
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import DEFAULT_CFG, ops
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
from .utils import bbox_iou
class FastSAMPredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'segment'
def postprocess(self, preds, img, orig_imgs):
"""TODO: filter by classes."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes)
results = []
if len(p) == 0 or len(p[0]) == 0:
print("No object detected.")
return results
full_box = torch.zeros_like(p[0][0])
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
full_box = full_box.view(1, -1)
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
if critical_iou_index.numel() != 0:
full_box[0][4] = p[0][critical_iou_index][:,4]
full_box[0][6:] = p[0][critical_iou_index][:,6:]
p[0][critical_iou_index] = full_box
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
for i, pred in enumerate(p):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
if not len(pred): # save empty boxes
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
continue
if self.args.retina_masks:
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
results.append(
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
return results
import os
import sys
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from .utils import image_to_np_ndarray
from PIL import Image
try:
import clip # for linear_assignment
except (ImportError, AssertionError, AttributeError):
from ultralytics.yolo.utils.checks import check_requirements
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
import clip
class FastSAMPrompt:
def __init__(self, image, results, device='cuda'):
if isinstance(image, str) or isinstance(image, Image.Image):
image = image_to_np_ndarray(image)
self.device = device
self.results = results
self.img = image
def _segment_image(self, image, bbox):
if isinstance(image, Image.Image):
image_array = np.array(image)
else:
image_array = image
segmented_image_array = np.zeros_like(image_array)
x1, y1, x2, y2 = bbox
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
segmented_image = Image.fromarray(segmented_image_array)
black_image = Image.new('RGB', image.size, (255, 255, 255))
# transparency_mask = np.zeros_like((), dtype=np.uint8)
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
transparency_mask[y1:y2, x1:x2] = 255
transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
black_image.paste(segmented_image, mask=transparency_mask_image)
return black_image
def _format_results(self, result, filter=0):
annotations = []
n = len(result.masks.data)
for i in range(n):
annotation = {}
mask = result.masks.data[i] == 1.0
if torch.sum(mask) < filter:
continue
annotation['id'] = i
annotation['segmentation'] = mask.cpu().numpy()
annotation['bbox'] = result.boxes.data[i]
annotation['score'] = result.boxes.conf[i]
annotation['area'] = annotation['segmentation'].sum()
annotations.append(annotation)
return annotations
def filter_masks(annotations): # filte the overlap mask
annotations.sort(key=lambda x: x['area'], reverse=True)
to_remove = set()
for i in range(0, len(annotations)):
a = annotations[i]
for j in range(i + 1, len(annotations)):
b = annotations[j]
if i != j and j not in to_remove:
# check if
if b['area'] < a['area']:
if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
to_remove.add(j)
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
def _get_bbox_from_mask(self, mask):
mask = mask.astype(np.uint8)
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
x1, y1, w, h = cv2.boundingRect(contours[0])
x2, y2 = x1 + w, y1 + h
if len(contours) > 1:
for b in contours:
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
# Merge multiple bounding boxes into one.
x1 = min(x1, x_t)
y1 = min(y1, y_t)
x2 = max(x2, x_t + w_t)
y2 = max(y2, y_t + h_t)
h = y2 - y1
w = x2 - x1
return [x1, y1, x2, y2]
def plot_to_result(self,
annotations,
bboxes=None,
points=None,
point_label=None,
mask_random_color=True,
better_quality=True,
retina=False,
withContours=True) -> np.ndarray:
if isinstance(annotations[0], dict):
annotations = [annotation['segmentation'] for annotation in annotations]
image = self.img
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
original_h = image.shape[0]
original_w = image.shape[1]
if sys.platform == "darwin":
plt.switch_backend("TkAgg")
plt.figure(figsize=(original_w / 100, original_h / 100))
# Add subplot with no margin.
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.imshow(image)
if better_quality:
if isinstance(annotations[0], torch.Tensor):
annotations = np.array(annotations.cpu())
for i, mask in enumerate(annotations):
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
if self.device == 'cpu':
annotations = np.array(annotations)
self.fast_show_mask(
annotations,
plt.gca(),
random_color=mask_random_color,
bboxes=bboxes,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
else:
if isinstance(annotations[0], np.ndarray):
annotations = torch.from_numpy(annotations)
self.fast_show_mask_gpu(
annotations,
plt.gca(),
random_color=mask_random_color,
bboxes=bboxes,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
if isinstance(annotations, torch.Tensor):
annotations = annotations.cpu().numpy()
if withContours:
contour_all = []
temp = np.zeros((original_h, original_w, 1))
for i, mask in enumerate(annotations):
if type(mask) == dict:
mask = mask['segmentation']
annotation = mask.astype(np.uint8)
if not retina:
annotation = cv2.resize(
annotation,
(original_w, original_h),
interpolation=cv2.INTER_NEAREST,
)
contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
contour_all.append(contour)
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
contour_mask = temp / 255 * color.reshape(1, 1, -1)
plt.imshow(contour_mask)
plt.axis('off')
fig = plt.gcf()
plt.draw()
try:
buf = fig.canvas.tostring_rgb()
except AttributeError:
fig.canvas.draw()
buf = fig.canvas.tostring_rgb()
cols, rows = fig.canvas.get_width_height()
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
plt.close()
return result
# Remark for refactoring: IMO a function should do one thing only, storing the image and plotting should be seperated and do not necessarily need to be class functions but standalone utility functions that the user can chain in his scripts to have more fine-grained control.
def plot(self,
annotations,
output_path,
bboxes=None,
points=None,
point_label=None,
mask_random_color=True,
better_quality=True,
retina=False,
withContours=True):
if len(annotations) == 0:
return None
result = self.plot_to_result(
annotations,
bboxes,
points,
point_label,
mask_random_color,
better_quality,
retina,
withContours,
)
path = os.path.dirname(os.path.abspath(output_path))
if not os.path.exists(path):
os.makedirs(path)
result = result[:, :, ::-1]
cv2.imwrite(output_path, result)
# CPU post process
def fast_show_mask(
self,
annotation,
ax,
random_color=False,
bboxes=None,
points=None,
pointlabel=None,
retinamask=True,
target_height=960,
target_width=960,
):
msak_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
#Sort annotations based on area.
areas = np.sum(annotation, axis=(1, 2))
sorted_indices = np.argsort(areas)
annotation = annotation[sorted_indices]
index = (annotation != 0).argmax(axis=0)
if random_color:
color = np.random.random((msak_sum, 1, 1, 3))
else:
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual
show = np.zeros((height, weight, 4))
h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
# Use vectorized indexing to update the values of 'show'.
show[h_indices, w_indices, :] = mask_image[indices]
if bboxes is not None:
for bbox in bboxes:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
# draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
s=20,
c='y',
)
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
s=20,
c='m',
)
if not retinamask:
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
ax.imshow(show)
def fast_show_mask_gpu(
self,
annotation,
ax,
random_color=False,
bboxes=None,
points=None,
pointlabel=None,
retinamask=True,
target_height=960,
target_width=960,
):
msak_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
areas = torch.sum(annotation, dim=(1, 2))
sorted_indices = torch.argsort(areas, descending=False)
annotation = annotation[sorted_indices]
# Find the index of the first non-zero value at each position.
index = (annotation != 0).to(torch.long).argmax(dim=0)
if random_color:
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
else:
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
visual = torch.cat([color, transparency], dim=-1)
mask_image = torch.unsqueeze(annotation, -1) * visual
# Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form.
show = torch.zeros((height, weight, 4)).to(annotation.device)
try:
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij')
except:
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
# Use vectorized indexing to update the values of 'show'.
show[h_indices, w_indices, :] = mask_image[indices]
show_cpu = show.cpu().numpy()
if bboxes is not None:
for bbox in bboxes:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
# draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
s=20,
c='y',
)
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
s=20,
c='m',
)
if not retinamask:
show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
ax.imshow(show_cpu)
# clip
@torch.no_grad()
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
preprocessed_images = [preprocess(image).to(device) for image in elements]
tokenized_text = clip.tokenize([search_text]).to(device)
stacked_images = torch.stack(preprocessed_images)
image_features = model.encode_image(stacked_images)
text_features = model.encode_text(tokenized_text)
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
probs = 100.0 * image_features @ text_features.T
return probs[:, 0].softmax(dim=0)
def _crop_image(self, format_results):
image = Image.fromarray(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
ori_w, ori_h = image.size
annotations = format_results
mask_h, mask_w = annotations[0]['segmentation'].shape
if ori_w != mask_w or ori_h != mask_h:
image = image.resize((mask_w, mask_h))
cropped_boxes = []
cropped_images = []
not_crop = []
filter_id = []
# annotations, _ = filter_masks(annotations)
# filter_id = list(_)
for _, mask in enumerate(annotations):
if np.sum(mask['segmentation']) <= 100:
filter_id.append(_)
continue
bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox
cropped_boxes.append(self._segment_image(image, bbox))
# cropped_boxes.append(segment_image(image,mask["segmentation"]))
cropped_images.append(bbox) # Save the bounding box of the cropped image.
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
def box_prompt(self, bbox=None, bboxes=None):
if self.results == None:
return []
assert bbox or bboxes
if bboxes is None:
bboxes = [bbox]
max_iou_index = []
for bbox in bboxes:
assert (bbox[2] != 0 and bbox[3] != 0)
masks = self.results[0].masks.data
target_height = self.img.shape[0]
target_width = self.img.shape[1]
h = masks.shape[1]
w = masks.shape[2]
if h != target_height or w != target_width:
bbox = [
int(bbox[0] * w / target_width),
int(bbox[1] * h / target_height),
int(bbox[2] * w / target_width),
int(bbox[3] * h / target_height), ]
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
orig_masks_area = torch.sum(masks, dim=(1, 2))
union = bbox_area + orig_masks_area - masks_area
IoUs = masks_area / union
max_iou_index.append(int(torch.argmax(IoUs)))
max_iou_index = list(set(max_iou_index))
return np.array(masks[max_iou_index].cpu().numpy())
def point_prompt(self, points, pointlabel): # numpy
if self.results == None:
return []
masks = self._format_results(self.results[0], 0)
target_height = self.img.shape[0]
target_width = self.img.shape[1]
h = masks[0]['segmentation'].shape[0]
w = masks[0]['segmentation'].shape[1]
if h != target_height or w != target_width:
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
onemask = np.zeros((h, w))
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
for i, annotation in enumerate(masks):
if type(annotation) == dict:
mask = annotation['segmentation']
else:
mask = annotation
for i, point in enumerate(points):
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
onemask[mask] = 1
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
onemask[mask] = 0
onemask = onemask >= 1
return np.array([onemask])
def text_prompt(self, text):
if self.results == None:
return []
format_results = self._format_results(self.results[0], 0)
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
max_idx = scores.argsort()
max_idx = max_idx[-1]
max_idx += sum(np.array(filter_id) <= int(max_idx))
return np.array([annotations[max_idx]['segmentation']])
def everything_prompt(self):
if self.results == None:
return []
return self.results[0].masks.data
import numpy as np
import torch
from PIL import Image
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
'''Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes: (n, 4)
image_shape: (height, width)
threshold: pixel threshold
Returns:
adjusted_boxes: adjusted bounding boxes
'''
# Image dimensions
h, w = image_shape
# Adjust boxes
boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor(
0, dtype=torch.float, device=boxes.device), boxes[:, 0]) # x1
boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor(
0, dtype=torch.float, device=boxes.device), boxes[:, 1]) # y1
boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor(
w, dtype=torch.float, device=boxes.device), boxes[:, 2]) # x2
boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor(
h, dtype=torch.float, device=boxes.device), boxes[:, 3]) # y2
return boxes
def convert_box_xywh_to_xyxy(box):
x1 = box[0]
y1 = box[1]
x2 = box[0] + box[2]
y2 = box[1] + box[3]
return [x1, y1, x2, y2]
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
'''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
Args:
box1: (4, )
boxes: (n, 4)
Returns:
high_iou_indices: Indices of boxes with IoU > thres
'''
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
# obtain coordinates for intersections
x1 = torch.max(box1[0], boxes[:, 0])
y1 = torch.max(box1[1], boxes[:, 1])
x2 = torch.min(box1[2], boxes[:, 2])
y2 = torch.min(box1[3], boxes[:, 3])
# compute the area of intersection
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
# compute the area of both individual boxes
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
# compute the area of union
union = box1_area + box2_area - intersection
# compute the IoU
iou = intersection / union # Should be shape (n, )
if raw_output:
if iou.numel() == 0:
return 0
return iou
# get indices of boxes with IoU > thres
high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
return high_iou_indices
def image_to_np_ndarray(image):
if type(image) is str:
return np.array(Image.open(image))
elif issubclass(type(image), Image.Image):
return np.array(image)
elif type(image) is np.ndarray:
return image
return None
icon.png

66.7 KB

# 模型唯一标识
modelCode = 528
# 模型名称
modelName=fastsam_pytorch
# 模型描述
modelDescription=FastSAM基于YOLACT方法的实例分割分支的目标检测器YOLOv8-seg,通过仅在SA-1B数据集的2%(1/50)上直接训练该CNN检测器,它实现了与SAM相当的性能。
# 应用场景
appScenario=推理,训练,金融,交通,教育
# 框架类型
frameType=pytorch
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
# Thanks for chenxwh.
import argparse
import cv2
import shutil
import ast
from cog import BasePredictor, Input, Path
from ultralytics import YOLO
from utils.tools import *
class Predictor(BasePredictor):
def setup(self):
"""Load the model into memory to make running multiple predictions efficient"""
self.models = {k: YOLO(f"{k}.pt") for k in ["FastSAM-s", "FastSAM-x"]}
def predict(
self,
input_image: Path = Input(description="Input image"),
model_name: str = Input(
description="choose a model",
choices=["FastSAM-x", "FastSAM-s"],
default="FastSAM-x",
),
iou: float = Input(
description="iou threshold for filtering the annotations", default=0.7
),
text_prompt: str = Input(
description='use text prompt eg: "a black dog"', default=None
),
conf: float = Input(description="object confidence threshold", default=0.25),
retina: bool = Input(
description="draw high-resolution segmentation masks", default=True
),
box_prompt: str = Input(default="[0,0,0,0]", description="[x,y,w,h]"),
point_prompt: str = Input(default="[[0,0]]", description="[[x1,y1],[x2,y2]]"),
point_label: str = Input(default="[0]", description="[1,0] 0:background, 1:foreground"),
withContours: bool = Input(
description="draw the edges of the masks", default=False
),
better_quality: bool = Input(
description="better quality using morphologyEx", default=False
),
) -> Path:
"""Run a single prediction on the model"""
# default params
out_path = "output"
if os.path.exists(out_path):
shutil.rmtree(out_path)
os.makedirs(out_path, exist_ok=True)
device = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
args = argparse.Namespace(
better_quality=better_quality,
box_prompt=box_prompt,
conf=conf,
device=device,
img_path=str(input_image),
imgsz=1024,
iou=iou,
model_path="FastSAM-x.pt",
output=out_path,
point_label=point_label,
point_prompt=point_prompt,
randomcolor=True,
retina=retina,
text_prompt=text_prompt,
withContours=withContours,
)
args.point_prompt = ast.literal_eval(args.point_prompt)
args.box_prompt = ast.literal_eval(args.box_prompt)
args.point_label = ast.literal_eval(args.point_label)
model = self.models[model_name]
results = model(
str(input_image),
imgsz=args.imgsz,
device=args.device,
retina_masks=args.retina,
iou=args.iou,
conf=args.conf,
max_det=100,
)
if args.box_prompt[2] != 0 and args.box_prompt[3] != 0:
annotations = prompt(results, args, box=True)
annotations = np.array([annotations])
fast_process(
annotations=annotations,
args=args,
mask_random_color=args.randomcolor,
bbox=convert_box_xywh_to_xyxy(args.box_prompt),
)
elif args.text_prompt != None:
results = format_results(results[0], 0)
annotations = prompt(results, args, text=True)
annotations = np.array([annotations])
fast_process(
annotations=annotations, args=args, mask_random_color=args.randomcolor
)
elif args.point_prompt[0] != [0, 0]:
results = format_results(results[0], 0)
annotations = prompt(results, args, point=True)
# list to numpy
annotations = np.array([annotations])
fast_process(
annotations=annotations,
args=args,
mask_random_color=args.randomcolor,
points=args.point_prompt,
)
else:
fast_process(
annotations=results[0].masks.data,
args=args,
mask_random_color=args.randomcolor,
)
out = "/tmp.out.png"
shutil.copy(os.path.join(out_path, os.listdir(out_path)[0]), out)
return Path(out)
def prompt(results, args, box=None, point=None, text=None):
ori_img = cv2.imread(args.img_path)
ori_h = ori_img.shape[0]
ori_w = ori_img.shape[1]
if box:
mask, idx = box_prompt(
results[0].masks.data,
convert_box_xywh_to_xyxy(args.box_prompt),
ori_h,
ori_w,
)
elif point:
mask, idx = point_prompt(
results, args.point_prompt, args.point_label, ori_h, ori_w
)
elif text:
mask, idx = text_prompt(results, args.text_prompt, args.img_path, args.device)
else:
return None
return mask
# Base-----------------------------------
matplotlib>=3.2.2
opencv-python>=4.6.0
Pillow>=7.1.2
PyYAML>=5.3.1
requests>=2.23.0
scipy>=1.4.1
torch>=1.7.0
torchvision>=0.8.1
tqdm>=4.64.0
pandas>=1.1.4
seaborn>=0.11.0
gradio==3.35.2
# Ultralytics-----------------------------------
ultralytics == 8.0.120
from fastsam import FastSAM, FastSAMPrompt
import torch
model = FastSAM('FastSAM.pt')
IMAGE_PATH = './images/dogs.jpg'
DEVICE = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
everything_results = model(
IMAGE_PATH,
device=DEVICE,
retina_masks=True,
imgsz=1024,
conf=0.4,
iou=0.9,
)
prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)
# # everything prompt
ann = prompt_process.everything_prompt()
# # bbox prompt
# # bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
# bboxes default shape [[0,0,0,0]] -> [[x1,y1,x2,y2]]
# ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
# ann = prompt_process.box_prompt(bboxes=[[200, 200, 300, 300], [500, 500, 600, 600]])
# # text prompt
# ann = prompt_process.text_prompt(text='a photo of a dog')
# # point prompt
# # points default [[0,0]] [[x1,y1],[x2,y2]]
# # point_label default [0] [1,0] 0:background, 1:foreground
# ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])
# point prompt
# points default [[0,0]] [[x1,y1],[x2,y2]]
# point_label default [0] [1,0] 0:background, 1:foreground
ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])
prompt_process.plot(
annotations=ann,
output='./output/',
mask_random_color=True,
better_quality=True,
retina=False,
withContours=True,
)
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from setuptools import find_packages, setup
REQUIREMENTS = [i.strip() for i in open("requirements.txt").readlines()]
REQUIREMENTS += [
"CLIP @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33#egg=CLIP"
]
setup(
name="fastsam",
version="0.1.1",
install_requires=REQUIREMENTS,
packages=["fastsam", "fastsam_tools"],
package_dir= {
"fastsam": "fastsam",
"fastsam_tools": "utils",
},
url="https://github.com/CASIA-IVA-Lab/FastSAM"
)
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import torch
import os
import sys
import clip
def convert_box_xywh_to_xyxy(box):
if len(box) == 4:
return [box[0], box[1], box[0] + box[2], box[1] + box[3]]
else:
result = []
for b in box:
b = convert_box_xywh_to_xyxy(b)
result.append(b)
return result
def segment_image(image, bbox):
image_array = np.array(image)
segmented_image_array = np.zeros_like(image_array)
x1, y1, x2, y2 = bbox
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
segmented_image = Image.fromarray(segmented_image_array)
black_image = Image.new("RGB", image.size, (255, 255, 255))
# transparency_mask = np.zeros_like((), dtype=np.uint8)
transparency_mask = np.zeros(
(image_array.shape[0], image_array.shape[1]), dtype=np.uint8
)
transparency_mask[y1:y2, x1:x2] = 255
transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
black_image.paste(segmented_image, mask=transparency_mask_image)
return black_image
def format_results(result, filter=0):
annotations = []
n = len(result.masks.data)
for i in range(n):
annotation = {}
mask = result.masks.data[i] == 1.0
if torch.sum(mask) < filter:
continue
annotation["id"] = i
annotation["segmentation"] = mask.cpu().numpy()
annotation["bbox"] = result.boxes.data[i]
annotation["score"] = result.boxes.conf[i]
annotation["area"] = annotation["segmentation"].sum()
annotations.append(annotation)
return annotations
def filter_masks(annotations): # filter the overlap mask
annotations.sort(key=lambda x: x["area"], reverse=True)
to_remove = set()
for i in range(0, len(annotations)):
a = annotations[i]
for j in range(i + 1, len(annotations)):
b = annotations[j]
if i != j and j not in to_remove:
# check if
if b["area"] < a["area"]:
if (a["segmentation"] & b["segmentation"]).sum() / b[
"segmentation"
].sum() > 0.8:
to_remove.add(j)
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
def get_bbox_from_mask(mask):
mask = mask.astype(np.uint8)
contours, hierarchy = cv2.findContours(
mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
x1, y1, w, h = cv2.boundingRect(contours[0])
x2, y2 = x1 + w, y1 + h
if len(contours) > 1:
for b in contours:
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
# 将多个bbox合并成一个
x1 = min(x1, x_t)
y1 = min(y1, y_t)
x2 = max(x2, x_t + w_t)
y2 = max(y2, y_t + h_t)
h = y2 - y1
w = x2 - x1
return [x1, y1, x2, y2]
def fast_process(
annotations, args, mask_random_color, bbox=None, points=None, edges=False
):
if isinstance(annotations[0], dict):
annotations = [annotation["segmentation"] for annotation in annotations]
result_name = os.path.basename(args.img_path)
image = cv2.imread(args.img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
original_h = image.shape[0]
original_w = image.shape[1]
if sys.platform == "darwin":
plt.switch_backend("TkAgg")
plt.figure(figsize=(original_w/100, original_h/100))
# Add subplot with no margin.
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.imshow(image)
if args.better_quality == True:
if isinstance(annotations[0], torch.Tensor):
annotations = np.array(annotations.cpu())
for i, mask in enumerate(annotations):
mask = cv2.morphologyEx(
mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
)
annotations[i] = cv2.morphologyEx(
mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
)
if args.device == "cpu":
annotations = np.array(annotations)
fast_show_mask(
annotations,
plt.gca(),
random_color=mask_random_color,
bbox=bbox,
points=points,
point_label=args.point_label,
retinamask=args.retina,
target_height=original_h,
target_width=original_w,
)
else:
if isinstance(annotations[0], np.ndarray):
annotations = torch.from_numpy(annotations)
fast_show_mask_gpu(
annotations,
plt.gca(),
random_color=args.randomcolor,
bbox=bbox,
points=points,
point_label=args.point_label,
retinamask=args.retina,
target_height=original_h,
target_width=original_w,
)
if isinstance(annotations, torch.Tensor):
annotations = annotations.cpu().numpy()
if args.withContours == True:
contour_all = []
temp = np.zeros((original_h, original_w, 1))
for i, mask in enumerate(annotations):
if type(mask) == dict:
mask = mask["segmentation"]
annotation = mask.astype(np.uint8)
if args.retina == False:
annotation = cv2.resize(
annotation,
(original_w, original_h),
interpolation=cv2.INTER_NEAREST,
)
contours, hierarchy = cv2.findContours(
annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
)
for contour in contours:
contour_all.append(contour)
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
contour_mask = temp / 255 * color.reshape(1, 1, -1)
plt.imshow(contour_mask)
save_path = args.output
if not os.path.exists(save_path):
os.makedirs(save_path)
plt.axis("off")
fig = plt.gcf()
plt.draw()
try:
buf = fig.canvas.tostring_rgb()
except AttributeError:
fig.canvas.draw()
buf = fig.canvas.tostring_rgb()
cols, rows = fig.canvas.get_width_height()
img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
# CPU post process
def fast_show_mask(
annotation,
ax,
random_color=False,
bbox=None,
points=None,
point_label=None,
retinamask=True,
target_height=960,
target_width=960,
):
msak_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
# 将annotation 按照面积 排序
areas = np.sum(annotation, axis=(1, 2))
sorted_indices = np.argsort(areas)
annotation = annotation[sorted_indices]
index = (annotation != 0).argmax(axis=0)
if random_color == True:
color = np.random.random((msak_sum, 1, 1, 3))
else:
color = np.ones((msak_sum, 1, 1, 3)) * np.array(
[30 / 255, 144 / 255, 255 / 255]
)
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual
show = np.zeros((height, weight, 4))
h_indices, w_indices = np.meshgrid(
np.arange(height), np.arange(weight), indexing="ij"
)
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
# 使用向量化索引更新show的值
show[h_indices, w_indices, :] = mask_image[indices]
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(
plt.Rectangle(
(x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
)
)
# draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if point_label[i] == 1],
[point[1] for i, point in enumerate(points) if point_label[i] == 1],
s=20,
c="y",
)
plt.scatter(
[point[0] for i, point in enumerate(points) if point_label[i] == 0],
[point[1] for i, point in enumerate(points) if point_label[i] == 0],
s=20,
c="m",
)
if retinamask == False:
show = cv2.resize(
show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
)
ax.imshow(show)
def fast_show_mask_gpu(
annotation,
ax,
random_color=False,
bbox=None,
points=None,
point_label=None,
retinamask=True,
target_height=960,
target_width=960,
):
msak_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
areas = torch.sum(annotation, dim=(1, 2))
sorted_indices = torch.argsort(areas, descending=False)
annotation = annotation[sorted_indices]
# 找每个位置第一个非零值下标
index = (annotation != 0).to(torch.long).argmax(dim=0)
if random_color == True:
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
else:
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor(
[30 / 255, 144 / 255, 255 / 255]
).to(annotation.device)
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
visual = torch.cat([color, transparency], dim=-1)
mask_image = torch.unsqueeze(annotation, -1) * visual
# 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
show = torch.zeros((height, weight, 4)).to(annotation.device)
h_indices, w_indices = torch.meshgrid(
torch.arange(height), torch.arange(weight), indexing="ij"
)
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
# 使用向量化索引更新show的值
show[h_indices, w_indices, :] = mask_image[indices]
show_cpu = show.cpu().numpy()
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(
plt.Rectangle(
(x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
)
)
# draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if point_label[i] == 1],
[point[1] for i, point in enumerate(points) if point_label[i] == 1],
s=20,
c="y",
)
plt.scatter(
[point[0] for i, point in enumerate(points) if point_label[i] == 0],
[point[1] for i, point in enumerate(points) if point_label[i] == 0],
s=20,
c="m",
)
if retinamask == False:
show_cpu = cv2.resize(
show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
)
ax.imshow(show_cpu)
# clip
@torch.no_grad()
def retriev(
model, preprocess, elements: [Image.Image], search_text: str, device
):
preprocessed_images = [preprocess(image).to(device) for image in elements]
tokenized_text = clip.tokenize([search_text]).to(device)
stacked_images = torch.stack(preprocessed_images)
image_features = model.encode_image(stacked_images)
text_features = model.encode_text(tokenized_text)
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
probs = 100.0 * image_features @ text_features.T
return probs[:, 0].softmax(dim=0)
def crop_image(annotations, image_like):
if isinstance(image_like, str):
image = Image.open(image_like)
else:
image = image_like
ori_w, ori_h = image.size
mask_h, mask_w = annotations[0]["segmentation"].shape
if ori_w != mask_w or ori_h != mask_h:
image = image.resize((mask_w, mask_h))
cropped_boxes = []
cropped_images = []
not_crop = []
origin_id = []
for _, mask in enumerate(annotations):
if np.sum(mask["segmentation"]) <= 100:
continue
origin_id.append(_)
bbox = get_bbox_from_mask(mask["segmentation"]) # mask 的 bbox
cropped_boxes.append(segment_image(image, bbox)) # 保存裁剪的图片
# cropped_boxes.append(segment_image(image,mask["segmentation"]))
cropped_images.append(bbox) # 保存裁剪的图片的bbox
return cropped_boxes, cropped_images, not_crop, origin_id, annotations
def box_prompt(masks, bbox, target_height, target_width):
h = masks.shape[1]
w = masks.shape[2]
if h != target_height or w != target_width:
bbox = [
int(bbox[0] * w / target_width),
int(bbox[1] * h / target_height),
int(bbox[2] * w / target_width),
int(bbox[3] * h / target_height),
]
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
orig_masks_area = torch.sum(masks, dim=(1, 2))
union = bbox_area + orig_masks_area - masks_area
IoUs = masks_area / union
max_iou_index = torch.argmax(IoUs)
return masks[max_iou_index].cpu().numpy(), max_iou_index
def point_prompt(masks, points, point_label, target_height, target_width): # numpy 处理
h = masks[0]["segmentation"].shape[0]
w = masks[0]["segmentation"].shape[1]
if h != target_height or w != target_width:
points = [
[int(point[0] * w / target_width), int(point[1] * h / target_height)]
for point in points
]
onemask = np.zeros((h, w))
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
for i, annotation in enumerate(masks):
if type(annotation) == dict:
mask = annotation['segmentation']
else:
mask = annotation
for i, point in enumerate(points):
if mask[point[1], point[0]] == 1 and point_label[i] == 1:
onemask[mask] = 1
if mask[point[1], point[0]] == 1 and point_label[i] == 0:
onemask[mask] = 0
onemask = onemask >= 1
return onemask, 0
def text_prompt(annotations, text, img_path, device, wider=False, threshold=0.9):
cropped_boxes, cropped_images, not_crop, origin_id, annotations_ = crop_image(
annotations, img_path
)
clip_model, preprocess = clip.load("ViT-B/32", device=device)
scores = retriev(
clip_model, preprocess, cropped_boxes, text, device=device
)
max_idx = scores.argsort()
max_idx = max_idx[-1]
max_idx = origin_id[int(max_idx)]
# find the biggest mask which contains the mask with max score
if wider:
mask0 = annotations_[max_idx]["segmentation"]
area0 = np.sum(mask0)
areas = [(i, np.sum(mask["segmentation"])) for i, mask in enumerate(annotations_) if i in origin_id]
areas = sorted(areas, key=lambda area: area[1], reverse=True)
indices = [area[0] for area in areas]
for index in indices:
if index == max_idx or np.sum(annotations_[index]["segmentation"] & mask0) / area0 > threshold:
max_idx = index
break
return annotations_[max_idx]["segmentation"], max_idx
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import torch
def fast_process(
annotations,
image,
device,
scale,
better_quality=False,
mask_random_color=True,
bbox=None,
use_retina=True,
withContours=True,
):
if isinstance(annotations[0], dict):
annotations = [annotation['segmentation'] for annotation in annotations]
original_h = image.height
original_w = image.width
if better_quality:
if isinstance(annotations[0], torch.Tensor):
annotations = np.array(annotations.cpu())
for i, mask in enumerate(annotations):
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
if device == 'cpu':
annotations = np.array(annotations)
inner_mask = fast_show_mask(
annotations,
plt.gca(),
random_color=mask_random_color,
bbox=bbox,
retinamask=use_retina,
target_height=original_h,
target_width=original_w,
)
else:
if isinstance(annotations[0], np.ndarray):
annotations = torch.from_numpy(annotations)
inner_mask = fast_show_mask_gpu(
annotations,
plt.gca(),
random_color=mask_random_color,
bbox=bbox,
retinamask=use_retina,
target_height=original_h,
target_width=original_w,
)
if isinstance(annotations, torch.Tensor):
annotations = annotations.cpu().numpy()
if withContours:
contour_all = []
temp = np.zeros((original_h, original_w, 1))
for i, mask in enumerate(annotations):
if type(mask) == dict:
mask = mask['segmentation']
annotation = mask.astype(np.uint8)
if use_retina == False:
annotation = cv2.resize(
annotation,
(original_w, original_h),
interpolation=cv2.INTER_NEAREST,
)
contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
contour_all.append(contour)
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
contour_mask = temp / 255 * color.reshape(1, 1, -1)
image = image.convert('RGBA')
overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
image.paste(overlay_inner, (0, 0), overlay_inner)
if withContours:
overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
image.paste(overlay_contour, (0, 0), overlay_contour)
return image
# CPU post process
def fast_show_mask(
annotation,
ax,
random_color=False,
bbox=None,
retinamask=True,
target_height=960,
target_width=960,
):
mask_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
# 将annotation 按照面积 排序
areas = np.sum(annotation, axis=(1, 2))
sorted_indices = np.argsort(areas)[::1]
annotation = annotation[sorted_indices]
index = (annotation != 0).argmax(axis=0)
if random_color:
color = np.random.random((mask_sum, 1, 1, 3))
else:
color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual
mask = np.zeros((height, weight, 4))
h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
mask[h_indices, w_indices, :] = mask_image[indices]
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
if not retinamask:
mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
return mask
def fast_show_mask_gpu(
annotation,
ax,
random_color=False,
bbox=None,
retinamask=True,
target_height=960,
target_width=960,
):
device = annotation.device
mask_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
areas = torch.sum(annotation, dim=(1, 2))
sorted_indices = torch.argsort(areas, descending=False)
annotation = annotation[sorted_indices]
# 找每个位置第一个非零值下标
index = (annotation != 0).to(torch.long).argmax(dim=0)
if random_color:
color = torch.rand((mask_sum, 1, 1, 3)).to(device)
else:
color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
[30 / 255, 144 / 255, 255 / 255]
).to(device)
transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
visual = torch.cat([color, transparency], dim=-1)
mask_image = torch.unsqueeze(annotation, -1) * visual
# 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
mask = torch.zeros((height, weight, 4)).to(device)
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
# 使用向量化索引更新show的值
mask[h_indices, w_indices, :] = mask_image[indices]
mask_cpu = mask.cpu().numpy()
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(
plt.Rectangle(
(x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
)
)
if not retinamask:
mask_cpu = cv2.resize(
mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
)
return mask_cpu
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment