import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.transforms import ToTensor
from PIL import Image

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)


def run_ours_box_or_points(img_path, pts_sampled, pts_labels, model):
    model = model.to(DEVICE)
    image_np = np.array(Image.open(img_path))
    img_tensor = ToTensor()(image_np).to(DEVICE)
    # Prompts are shaped [batch, num_queries, num_points, 2] and labels
    # [batch, num_queries, num_points].
    pts_sampled = torch.reshape(torch.tensor(pts_sampled), [1, 1, -1, 2]).to(DEVICE)
    pts_labels = torch.reshape(torch.tensor(pts_labels), [1, 1, -1]).to(DEVICE)
    predicted_logits, predicted_iou = model(
        img_tensor[None, ...],
        pts_sampled,
        pts_labels,
    )
    # Sort the candidate masks by predicted IoU so index 0 is the best one.
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    predicted_logits = torch.take_along_dim(
        predicted_logits, sorted_ids[..., None, None], dim=2
    )
    # Threshold the top-ranked mask's logits at 0 to get a binary mask.
    return torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.8])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    # Foreground points (label 1) are drawn as green stars, background
    # points (label 0) as red stars.
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="yellow", facecolor=(0, 0, 0, 0), lw=5)
    )


def show_anns_ours(mask, ax):
    # Overlay the mask as a translucent green layer on the current axes.
    ax.set_autoscale_on(False)
    img = np.ones((mask.shape[0], mask.shape[1], 4))
    img[:, :, 3] = 0
    color_mask = [0, 1, 0, 0.7]
    img[np.logical_not(mask)] = color_mask
    ax.imshow(img)


from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
# from squeeze_sam.build_squeeze_sam import build_squeeze_sam
import zipfile

efficient_sam_vitt_model = build_efficient_sam_vitt()
efficient_sam_vitt_model.eval()

# Since the EfficientSAM-S checkpoint file is >100MB, we store it as a zip file.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", "r") as zip_ref:
    zip_ref.extractall("weights")
efficient_sam_vits_model = build_efficient_sam_vits()
efficient_sam_vits_model.eval()

# squeeze_sam_model = build_squeeze_sam()
# squeeze_sam_model.eval()

# Box prompt: (x1, y1) is the top-left corner, (x2, y2) the bottom-right.
x1 = 400
y1 = 200
x2 = 800
y2 = 600

fig, ax = plt.subplots(1, 3, figsize=(30, 30))
# Labels 2 and 3 encode the two box corners in the SAM-style prompt format.
input_point = np.array([[x1, y1], [x2, y2]])
input_label = np.array([2, 3])
image_path = "figs/examples/dogs.jpg"
image = np.array(Image.open(image_path))
show_points(input_point, input_label, ax[0])
show_box([x1, y1, x2, y2], ax[0])
ax[0].imshow(image)

ax[1].imshow(image)
mask_efficient_sam_vitt = run_ours_box_or_points(
    image_path, input_point, input_label, efficient_sam_vitt_model
)
show_anns_ours(mask_efficient_sam_vitt, ax[1])
ax[1].title.set_text("EfficientSAM (VIT-tiny)")
ax[1].axis("off")

ax[2].imshow(image)
mask_efficient_sam_vits = run_ours_box_or_points(
    image_path, input_point, input_label, efficient_sam_vits_model
)
show_anns_ours(mask_efficient_sam_vits, ax[2])
ax[2].title.set_text("EfficientSAM (VIT-small)")
ax[2].axis("off")

# ax[3].imshow(image)
# mask_squeeze_sam = run_ours_box_or_points(image_path, input_point, input_label, squeeze_sam_model)
# show_anns_ours(mask_squeeze_sam, ax[3])
# ax[3].title.set_text("SqueezeSAM")
# ax[3].axis("off")

plt.savefig("results/efficientsam_box.png", bbox_inches="tight")
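
# A minimal sketch (not part of the original example): run_ours_box_or_points
# also accepts point prompts. Label 1 marks a foreground point and label 0 a
# background point, matching the convention used by show_points above. The
# point coordinates and output path below are illustrative assumptions.
#
# point_prompt = np.array([[580, 350]])
# point_label = np.array([1])
# point_mask = run_ours_box_or_points(
#     image_path, point_prompt, point_label, efficient_sam_vitt_model
# )
# plt.figure(figsize=(10, 10))
# plt.imshow(image)
# show_anns_ours(point_mask, plt.gca())
# show_points(point_prompt, point_label, plt.gca())
# plt.axis("off")
# plt.savefig("results/efficientsam_point.png", bbox_inches="tight")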