Commit b0f7a242 by luopl: Initial commit
import math
from typing import Tuple, Type
import torch
from torch import nn, Tensor
from .mlp import MLPBlock
class TwoWayTransformer(nn.Module):
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module],
normalize_before_activation: bool,
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using
queries whose positional embedding is supplied.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
"""
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
for i in range(depth):
curr_layer = TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
normalize_before_activation=normalize_before_activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
)
self.layers.append(curr_layer)
self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock(
embedding_dim,
num_heads,
downsample_rate=attention_downsample_rate,
)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
def forward(
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Apply transformer blocks and final layernorm
        for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attention layer from the points to the image
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)
return queries, keys
class TwoWayAttentionBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module],
normalize_before_activation: bool,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          normalize_before_activation (bool): whether normalization is applied
            before the activation inside the MLP block
          attention_downsample_rate (int): the factor by which the internal
            attention dimension is downscaled relative to embedding_dim
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
super().__init__()
self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock(
embedding_dim,
num_heads,
downsample_rate=attention_downsample_rate,
)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(
embedding_dim,
mlp_dim,
embedding_dim,
1,
activation,
)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock(
embedding_dim,
num_heads,
downsample_rate=attention_downsample_rate,
)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]:
# Self attention block
if not self.skip_first_layer_pe:
queries = queries + query_pe
attn_out = self.self_attn(q=queries, k=queries, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
class AttentionForTwoWayAttentionBlock(nn.Module):
"""
An attention layer that allows for downscaling the size of the embedding
after projection to queries, keys, and values.
"""
def __init__(
self,
embedding_dim: int,
num_heads: int,
downsample_rate: int = 1,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim // downsample_rate."
self.c_per_head = self.internal_dim / num_heads
self.inv_sqrt_c_per_head = 1.0 / math.sqrt(self.c_per_head)
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
self._reset_parameters()
def _reset_parameters(self) -> None:
# The fan_out is incorrect, but matches pytorch's initialization
# for which qkv is a single 3*embedding_dim x embedding_dim matrix
fan_in = self.embedding_dim
fan_out = 3 * self.internal_dim
# Xavier uniform with our custom fan_out
bnd = math.sqrt(6 / (fan_in + fan_out))
nn.init.uniform_(self.q_proj.weight, -bnd, bnd)
nn.init.uniform_(self.k_proj.weight, -bnd, bnd)
nn.init.uniform_(self.v_proj.weight, -bnd, bnd)
# out_proj.weight is left with default initialization, like pytorch attention
nn.init.zeros_(self.q_proj.bias)
nn.init.zeros_(self.k_proj.bias)
nn.init.zeros_(self.v_proj.bias)
nn.init.zeros_(self.out_proj.bias)
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
def _recombine_heads(self, x: Tensor) -> Tensor:
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# Input projections
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# Attention
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
attn = attn * self.inv_sqrt_c_per_head
attn = torch.softmax(attn, dim=-1)
# Get output
out = attn @ v
out = self._recombine_heads(out)
out = self.out_proj(out)
return out
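

# A minimal shape-check sketch (not part of the upstream module). It shows how
# TwoWayTransformer consumes a B x C x H x W image embedding, a matching
# positional encoding, and a B x N_points x C point embedding. All values below
# are illustrative assumptions; run this file as a module within its package
# (the relative `.mlp` import above fails when the file is run directly).
if __name__ == "__main__":
    transformer = TwoWayTransformer(
        depth=2,
        embedding_dim=256,
        num_heads=8,
        mlp_dim=2048,
        activation=nn.GELU,
        normalize_before_activation=False,
    )
    image_embedding = torch.randn(1, 256, 64, 64)
    image_pe = torch.randn(1, 256, 64, 64)
    point_embedding = torch.randn(1, 5, 256)
    queries, keys = transformer(image_embedding, image_pe, point_embedding)
    # queries: (1, 5, 256) processed point embeddings
    # keys:    (1, 4096, 256) processed, flattened image embeddings
    print(queries.shape, keys.shape)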
# ONNX export code is from the [labelme annotation tool](https://github.com/labelmeai/efficient-sam). Huge thanks to Kentaro Wada.
import onnxruntime
import torch
from efficient_sam.build_efficient_sam import build_efficient_sam_vits
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt
import onnx_models
def export_onnx(onnx_model, output, dynamic_axes, dummy_inputs, output_names):
with open(output, "wb") as f:
print(f"Exporting onnx model to {output}...")
torch.onnx.export(
onnx_model,
tuple(dummy_inputs.values()),
f,
export_params=True,
verbose=False,
opset_version=17,
do_constant_folding=True,
input_names=list(dummy_inputs.keys()),
output_names=output_names,
dynamic_axes=dynamic_axes,
)
inference_session = onnxruntime.InferenceSession(output)
    outputs = inference_session.run(
        output_names=output_names,
        input_feed={k: v.numpy() for k, v in dummy_inputs.items()},
    )
    print(output_names)
    print([output_i.shape for output_i in outputs])
def export_onnx_esam(model, output):
onnx_model = onnx_models.OnnxEfficientSam(model=model)
dynamic_axes = {
"batched_images": {0: "batch", 2: "height", 3: "width"},
"batched_point_coords": {2: "num_points"},
"batched_point_labels": {2: "num_points"},
}
dummy_inputs = {
"batched_images": torch.randn(1, 3, 1080, 1920, dtype=torch.float),
"batched_point_coords": torch.randint(
low=0, high=1080, size=(1, 1, 5, 2), dtype=torch.float
),
"batched_point_labels": torch.randint(
low=0, high=4, size=(1, 1, 5), dtype=torch.float
),
}
output_names = ["output_masks", "iou_predictions"]
export_onnx(
onnx_model=onnx_model,
output=output,
dynamic_axes=dynamic_axes,
dummy_inputs=dummy_inputs,
output_names=output_names,
)
def export_onnx_esam_encoder(model, output):
onnx_model = onnx_models.OnnxEfficientSamEncoder(model=model)
dynamic_axes = {
"batched_images": {0: "batch", 2: "height", 3: "width"},
}
dummy_inputs = {
"batched_images": torch.randn(1, 3, 1080, 1920, dtype=torch.float),
}
output_names = ["image_embeddings"]
export_onnx(
onnx_model=onnx_model,
output=output,
dynamic_axes=dynamic_axes,
dummy_inputs=dummy_inputs,
output_names=output_names,
)
def export_onnx_esam_decoder(model, output):
onnx_model = onnx_models.OnnxEfficientSamDecoder(model=model)
dynamic_axes = {
"image_embeddings": {0: "batch"},
"batched_point_coords": {2: "num_points"},
"batched_point_labels": {2: "num_points"},
}
dummy_inputs = {
"image_embeddings": torch.randn(1, 256, 64, 64, dtype=torch.float),
"batched_point_coords": torch.randint(
low=0, high=1080, size=(1, 1, 5, 2), dtype=torch.float
),
"batched_point_labels": torch.randint(
low=0, high=4, size=(1, 1, 5), dtype=torch.float
),
"orig_im_size": torch.tensor([1080, 1920], dtype=torch.long),
}
output_names = ["output_masks", "iou_predictions"]
export_onnx(
onnx_model=onnx_model,
output=output,
dynamic_axes=dynamic_axes,
dummy_inputs=dummy_inputs,
output_names=output_names,
)
def main():
    # EfficientSAM-Ti (ViT-tiny): smaller and faster
export_onnx_esam(
model=build_efficient_sam_vitt(),
output="weights/efficient_sam_vitt.onnx",
)
export_onnx_esam_encoder(
model=build_efficient_sam_vitt(),
output="weights/efficient_sam_vitt_encoder.onnx",
)
export_onnx_esam_decoder(
model=build_efficient_sam_vitt(),
output="weights/efficient_sam_vitt_decoder.onnx",
)
    # EfficientSAM-S (ViT-small): larger but more accurate
export_onnx_esam(
model=build_efficient_sam_vits(),
output="weights/efficient_sam_vits.onnx",
)
export_onnx_esam_encoder(
model=build_efficient_sam_vits(),
output="weights/efficient_sam_vits_encoder.onnx",
)
export_onnx_esam_decoder(
model=build_efficient_sam_vits(),
output="weights/efficient_sam_vits_decoder.onnx",
)
if __name__ == "__main__":
main()
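
# A hedged usage sketch (not part of the original export script): after running
# main(), the combined ViT-tiny export can be served with onnxruntime alone.
# The file name matches the export above; the image and prompts below are random
# placeholders laid out like the dummy inputs used for export.
#
#   import numpy as np
#   session = onnxruntime.InferenceSession("weights/efficient_sam_vitt.onnx")
#   masks, iou_predictions = session.run(
#       output_names=["output_masks", "iou_predictions"],
#       input_feed={
#           "batched_images": np.random.rand(1, 3, 1080, 1920).astype(np.float32),
#           "batched_point_coords": np.array(
#               [[[[580.0, 350.0], [650.0, 350.0]]]], dtype=np.float32
#           ),
#           "batched_point_labels": np.array([[[1.0, 1.0]]], dtype=np.float32),
#       },
#   )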
import torch
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
# from squeeze_sam.build_squeeze_sam import build_squeeze_sam
import zipfile
import os
# Efficient SAM (VIT-tiny)
torch.jit.save(torch.jit.script(build_efficient_sam_vitt()), "torchscripted_model/efficient_sam_vitt_torchscript.pt")
# Efficient SAM (VIT-small)
# Since the ViT-small checkpoint is >100 MB, it is stored as a zip file and extracted here.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
zip_ref.extractall("weights")
torch.jit.save(torch.jit.script(build_efficient_sam_vits()), "torchscripted_model/efficient_sam_vits_torchscript.pt")
# Squeeze SAM (UNET)
# torch.jit.save(torch.jit.script(build_squeeze_sam()), "torchscripted_model/squeeze_sam_torchscript.pt")
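
# A hedged usage sketch (not part of the original script): the TorchScript files
# saved above can be loaded back without the Python model definitions. The path
# matches the ViT-tiny save call; input shapes are illustrative assumptions
# (batched image, batched point coordinates, batched point labels).
#
#   loaded = torch.jit.load("torchscripted_model/efficient_sam_vitt_torchscript.pt")
#   loaded.eval()
#   with torch.no_grad():
#       predicted_logits, predicted_iou = loaded(
#           torch.randn(1, 3, 1024, 1024),
#           torch.tensor([[[[580.0, 350.0], [650.0, 350.0]]]]),
#           torch.tensor([[[1.0, 1.0]]]),
#       )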
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.transforms import ToTensor
from PIL import Image
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
def run_ours_box_or_points(img_path, pts_sampled, pts_labels, model):
model = model.to(DEVICE)
image_np = np.array(Image.open(img_path))
img_tensor = ToTensor()(image_np)
img_tensor = img_tensor.to(DEVICE)
pts_sampled = torch.reshape(torch.tensor(pts_sampled), [1, 1, -1, 2])
pts_labels = torch.reshape(torch.tensor(pts_labels), [1, 1, -1])
pts_sampled = pts_sampled.to(DEVICE)
pts_labels = pts_labels.to(DEVICE)
predicted_logits, predicted_iou = model(
img_tensor[None, ...],
pts_sampled,
pts_labels,
)
sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
predicted_logits = torch.take_along_dim(
predicted_logits, sorted_ids[..., None, None], dim=2
)
return torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.8])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0],
pos_points[:, 1],
color="green",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
ax.scatter(
neg_points[:, 0],
neg_points[:, 1],
color="red",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(
plt.Rectangle((x0, y0), w, h, edgecolor="yellow", facecolor=(0, 0, 0, 0), lw=5)
)
def show_anns_ours(mask, ax):
ax.set_autoscale_on(False)
img = np.ones((mask.shape[0], mask.shape[1], 4))
img[:, :, 3] = 0
color_mask = [0, 1, 0, 0.7]
img[np.logical_not(mask)] = color_mask
ax.imshow(img)
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
# from squeeze_sam.build_squeeze_sam import build_squeeze_sam
import zipfile
efficient_sam_vitt_model = build_efficient_sam_vitt()
efficient_sam_vitt_model.eval()
# Since the EfficientSAM-S checkpoint is >100 MB, it is stored as a zip file and extracted here.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
zip_ref.extractall("weights")
efficient_sam_vits_model = build_efficient_sam_vits()
efficient_sam_vits_model.eval()
# squeeze_sam_model = build_squeeze_sam()
# squeeze_sam_model.eval()
# Box prompt: following the SAM point-label convention, the box is passed as its
# two corner points with labels 2 (top-left) and 3 (bottom-right).
x1 = 400
y1 = 200
x2 = 800
y2 = 600
fig, ax = plt.subplots(1, 3, figsize=(30, 30))
input_point = np.array([[x1, y1], [x2, y2]])
input_label = np.array([2, 3])
image_path = "figs/examples/dogs.jpg"
image = np.array(Image.open(image_path))
show_points(input_point, input_label, ax[0])
show_box([x1,y1,x2,y2], ax[0])
ax[0].imshow(image)
ax[1].imshow(image)
mask_efficient_sam_vitt = run_ours_box_or_points(image_path, input_point, input_label, efficient_sam_vitt_model)
show_anns_ours(mask_efficient_sam_vitt, ax[1])
ax[1].title.set_text("EfficientSAM (VIT-tiny)")
ax[1].axis('off')
ax[2].imshow(image)
mask_efficient_sam_vits = run_ours_box_or_points(image_path, input_point, input_label, efficient_sam_vits_model)
show_anns_ours(mask_efficient_sam_vits, ax[2])
ax[2].title.set_text("EfficientSAM (VIT-small)")
ax[2].axis('off')
# ax[3].imshow(image)
# mask_squeeze_sam = run_ours_box_or_points(image_path, input_point, input_label, squeeze_sam_model)
# show_anns_ours(mask_squeeze_sam, ax[3])
# ax[3].title.set_text("SqueezeSAM")
# ax[3].axis('off')
plt.savefig("results/efficientsam_box.png", bbox_inches='tight')
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.transforms import ToTensor
from PIL import Image
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
def run_ours_box_or_points(img_path, pts_sampled, pts_labels, model):
model = model.to(DEVICE)
image_np = np.array(Image.open(img_path))
img_tensor = ToTensor()(image_np)
img_tensor = img_tensor.to(DEVICE)
pts_sampled = torch.reshape(torch.tensor(pts_sampled), [1, 1, -1, 2])
pts_labels = torch.reshape(torch.tensor(pts_labels), [1, 1, -1])
pts_sampled = pts_sampled.to(DEVICE)
pts_labels = pts_labels.to(DEVICE)
predicted_logits, predicted_iou = model(
img_tensor[None, ...],
pts_sampled,
pts_labels,
)
sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
predicted_logits = torch.take_along_dim(
predicted_logits, sorted_ids[..., None, None], dim=2
)
return torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.8])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0],
pos_points[:, 1],
color="green",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
ax.scatter(
neg_points[:, 0],
neg_points[:, 1],
color="red",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(
plt.Rectangle((x0, y0), w, h, edgecolor="yellow", facecolor=(0, 0, 0, 0), lw=5)
)
def show_anns_ours(mask, ax):
ax.set_autoscale_on(False)
img = np.ones((mask.shape[0], mask.shape[1], 4))
img[:, :, 3] = 0
color_mask = [0, 1, 0, 0.7]
img[np.logical_not(mask)] = color_mask
ax.imshow(img)
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
# from squeeze_sam.build_squeeze_sam import build_squeeze_sam
import zipfile
efficient_sam_vitt_model = build_efficient_sam_vitt()
efficient_sam_vitt_model.eval()
# Since the EfficientSAM-S checkpoint is >100 MB, it is stored as a zip file and extracted here.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
zip_ref.extractall("weights")
efficient_sam_vits_model = build_efficient_sam_vits()
efficient_sam_vits_model.eval()
# squeeze_sam_model = build_squeeze_sam()
# squeeze_sam_model.eval()
fig, ax = plt.subplots(1, 3, figsize=(30, 30))
# Point prompt: label 1 marks a foreground (positive) click.
input_label = np.array([1, 1])
image_path = "figs/examples/dogs.jpg"
input_point = np.array([[580, 350], [650, 350]])
image = np.array(Image.open(image_path))
show_points(input_point, input_label, ax[0])
ax[0].imshow(image)
ax[1].imshow(image)
mask_efficient_sam_vitt = run_ours_box_or_points(image_path, input_point, input_label, efficient_sam_vitt_model)
show_anns_ours(mask_efficient_sam_vitt, ax[1])
ax[1].title.set_text("EfficientSAM (VIT-tiny)")
ax[1].axis('off')
ax[2].imshow(image)
mask_efficient_sam_vits = run_ours_box_or_points(image_path, input_point, input_label, efficient_sam_vits_model)
show_anns_ours(mask_efficient_sam_vits, ax[2])
ax[2].title.set_text("EfficientSAM (VIT-small)")
ax[2].axis('off')
# ax[3].imshow(image)
# mask_squeeze_sam = run_ours_box_or_points(image_path, input_point, input_label, squeeze_sam_model)
# show_anns_ours(mask_squeeze_sam, ax[3])
# ax[3].title.set_text("SqueezeSAM")
# ax[3].axis('off')
plt.savefig("results/efficientsam_point.png", bbox_inches='tight')
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.transforms import ToTensor
from PIL import Image
import cv2
from segment_anything.utils.amg import (
    batched_mask_to_box,
    calculate_stability_score,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
)
from torchvision.ops.boxes import batched_nms, box_area

# Number of grid points per side for the "segment everything" point grid.
GRID_SIZE = 32
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def process_small_region(rles):
new_masks = []
scores = []
min_area = 100
nms_thresh = 0.7
for rle in rles:
mask = rle_to_mask(rle[0])
mask, changed = remove_small_regions(mask, min_area, mode="holes")
unchanged = not changed
mask, changed = remove_small_regions(mask, min_area, mode="islands")
unchanged = unchanged and not changed
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
# Give score=0 to changed masks and score=1 to unchanged masks
# so NMS will prefer ones that didn't need postprocessing
scores.append(float(unchanged))
# Recalculate boxes and remove any new duplicates
masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(masks)
keep_by_nms = batched_nms(
boxes.float(),
torch.as_tensor(scores),
torch.zeros_like(boxes[:, 0]), # categories
iou_threshold=nms_thresh,
)
# Only recalculate RLEs for masks that have changed
for i_mask in keep_by_nms:
if scores[i_mask] == 0.0:
mask_torch = masks[i_mask].unsqueeze(0)
rles[i_mask] = mask_to_rle_pytorch(mask_torch)
masks = [rle_to_mask(rles[i][0]) for i in keep_by_nms]
return masks
def get_predictions_given_embeddings_and_queries(img, points, point_labels, model):
predicted_masks, predicted_iou = model(
img[None, ...], points, point_labels
)
sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
predicted_iou_scores = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
predicted_masks = torch.take_along_dim(
predicted_masks, sorted_ids[..., None, None], dim=2
)
predicted_masks = predicted_masks[0]
iou = predicted_iou_scores[0, :, 0]
index_iou = iou > 0.7
iou_ = iou[index_iou]
masks = predicted_masks[index_iou]
score = calculate_stability_score(masks, 0.0, 1.0)
score = score[:, 0]
index = score > 0.9
score_ = score[index]
masks = masks[index]
iou_ = iou_[index]
masks = torch.ge(masks, 0.0)
return masks, iou_
def run_everything_ours(img_path, model):
model = model.to(DEVICE)
    image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
img_tensor = ToTensor()(image)
_, original_image_h, original_image_w = img_tensor.shape
xy = []
for i in range(GRID_SIZE):
curr_x = 0.5 + i / GRID_SIZE * original_image_w
for j in range(GRID_SIZE):
curr_y = 0.5 + j / GRID_SIZE * original_image_h
xy.append([curr_x, curr_y])
xy = torch.from_numpy(np.array(xy))
points = xy
num_pts = xy.shape[0]
point_labels = torch.ones(num_pts, 1)
with torch.no_grad():
predicted_masks, predicted_iou = get_predictions_given_embeddings_and_queries(
img_tensor.to(DEVICE),
points.reshape(1, num_pts, 1, 2).to(DEVICE),
point_labels.reshape(1, num_pts, 1).to(DEVICE),
model.to(DEVICE),
)
rle = [mask_to_rle_pytorch(m[0:1]) for m in predicted_masks]
predicted_masks = process_small_region(rle)
return predicted_masks
def show_anns_ours(mask, ax):
ax.set_autoscale_on(False)
img = np.ones((mask[0].shape[0], mask[0].shape[1], 4))
img[:,:,3] = 0
for ann in mask:
m = ann
color_mask = np.concatenate([np.random.random(3), [0.5]])
img[m] = color_mask
ax.imshow(img)
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
# from squeeze_sam.build_squeeze_sam import build_squeeze_sam
import zipfile
efficient_sam_vitt_model = build_efficient_sam_vitt()
efficient_sam_vitt_model.eval()
# Since the EfficientSAM-S checkpoint is >100 MB, it is stored as a zip file and extracted here.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
zip_ref.extractall("weights")
efficient_sam_vits_model = build_efficient_sam_vits()
efficient_sam_vits_model.eval()
fig, ax = plt.subplots(1, 3, figsize=(30, 30))
image_path = "figs/examples/dogs.jpg"
image = np.array(Image.open(image_path))
ax[0].imshow(image)
ax[0].title.set_text("Original")
ax[0].axis('off')
ax[1].imshow(image)
mask_efficient_sam_vitt = run_everything_ours(image_path, efficient_sam_vitt_model)
show_anns_ours(mask_efficient_sam_vitt, ax[1])
ax[1].title.set_text("EfficientSAM (VIT-tiny)")
ax[1].axis('off')
ax[2].imshow(image)
mask_efficient_sam_vits = run_everything_ours(image_path, efficient_sam_vits_model)
show_anns_ours(mask_efficient_sam_vits, ax[2])
ax[2].title.set_text("EfficientSAM (VIT-small)")
ax[2].axis('off')
plt.savefig("results/segmenteverything.png", bbox_inches='tight')
#!/bin/bash -e
# Copyright (c) Facebook, Inc. and its affiliates.
# Run this linter from the repository root before committing.
{
black --version | grep -E "23\." > /dev/null
} || {
echo "Linter requires 'black==23.*' !"
exit 1
}
ISORT_VERSION=$(isort --version-number)
if [[ "$ISORT_VERSION" != 5.12* ]]; then
echo "Linter requires isort==5.12.0 !"
exit 1
fi
echo "Running isort ..."
isort . --atomic
echo "Running black ..."
black -l 100 .
echo "Running flake8 ..."
if [ -x "$(command -v flake8)" ]; then
flake8 .
else
python3 -m flake8 .
fi
echo "Running mypy..."
mypy --exclude 'setup.py|notebooks' .
# Unique model identifier
modelCode=638
# Model name
modelName=efficientsam_pytorch
# Model description
modelDescription=EfficientSAM inference with point prompts, box prompts, and segment-everything
# Application scenarios
appScenario=inference,manufacturing,broadcast media,energy,healthcare,smart home,education
# Framework type
frameType=pytorch
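
# A hedged sketch (not part of this metadata file): the key=value entries above
# can be read in Python with a few lines. The file name below is an assumption.
#
#   metadata = {}
#   with open("model_info.properties", encoding="utf-8") as fh:
#       for line in fh:
#           line = line.strip()
#           if line and not line.startswith("#"):
#               key, _, value = line.partition("=")
#               metadata[key.strip()] = value.strip()
#   print(metadata["modelName"])  # -> efficientsam_pytorch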