"src/vscode:/vscode.git/clone" did not exist on "2576647c1ffc0087bbc137e81237616f9815d51d"
Commit b0f7a242 authored by luopl's avatar luopl
Browse files

Initial commit

parents
Pipeline #1000 canceled with stages
File added
*.pyc
*.pyo
*.pyd
__py
**/__pycache__/
# Default ignored files
/shelf/
/workspace.xml
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/venv" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.8 (efficientsam_pytorch)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (efficientsam_pytorch)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/efficientsam_pytorch.iml" filepath="$PROJECT_DIR$/.idea/efficientsam_pytorch.iml" />
</modules>
</component>
</project>
\ No newline at end of file
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk23.10-py38
COPY requirements.txt requirements.txt
RUN source /opt/dtk-23.10/env.sh
ENV LANG C.UTF-8
RUN pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
# from squeeze_sam.build_squeeze_sam import build_squeeze_sam
from PIL import Image
from torchvision import transforms
import torch
import numpy as np
import zipfile
models = {}
# Build the EfficientSAM-Ti model.
models['efficientsam_ti'] = build_efficient_sam_vitt()
# Since EfficientSAM-S checkpoint file is >100MB, we store the zip file.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
zip_ref.extractall("weights")
# Build the EfficientSAM-S model.
models['efficientsam_s'] = build_efficient_sam_vits()
# Build the SqueezeSAM model.
# models['squeeze_sam'] = build_squeeze_sam()
# load an image
sample_image_np = np.array(Image.open("figs/examples/dogs.jpg"))
sample_image_tensor = transforms.ToTensor()(sample_image_np)
# Feed a few (x,y) points in the mask as input.
input_points = torch.tensor([[[[580, 350], [650, 350]]]])
input_labels = torch.tensor([[[1, 1]]])
# Run inference for both EfficientSAM-Ti and EfficientSAM-S models.
for model_name, model in models.items():
print('Running inference using ', model_name)
predicted_logits, predicted_iou = model(
sample_image_tensor[None, ...],
input_points,
input_labels,
)
sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
predicted_logits = torch.take_along_dim(
predicted_logits, sorted_ids[..., None, None], dim=2
)
# The masks are already sorted by their predicted IOUs.
# The first dimension is the batch size (we have a single image. so it is 1).
# The second dimension is the number of masks we want to generate (in this case, it is only 1)
# The third dimension is the number of candidate masks output by the model.
# For this demo we use the first mask.
mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()
masked_image_np = sample_image_np.copy().astype(np.uint8) * mask[:,:,None]
Image.fromarray(masked_image_np).save(f"figs/examples/dogs_{model_name}_mask.png")
#Onnx export code is from [labelme annotation tool](https://github.com/labelmeai/efficient-sam). Huge thanks to Kentaro Wada.
import numpy as np
import torch
import imgviz
import onnxruntime
import time
from PIL import Image
def predict_onnx(input_image, input_points, input_labels):
if 0:
inference_session = onnxruntime.InferenceSession(
"weights/efficient_sam_vitt.onnx"
)
(
predicted_logits,
predicted_iou,
predicted_lowres_logits,
) = inference_session.run(
output_names=None,
input_feed={
"batched_images": input_image,
"batched_point_coords": input_points,
"batched_point_labels": input_labels,
},
)
else:
inference_session = onnxruntime.InferenceSession(
"weights/efficient_sam_vitt_encoder.onnx"
)
t_start = time.time()
image_embeddings, = inference_session.run(
output_names=None,
input_feed={
"batched_images": input_image,
},
)
print("encoder time", time.time() - t_start)
inference_session = onnxruntime.InferenceSession(
"weights/efficient_sam_vitt_decoder.onnx"
)
t_start = time.time()
(
predicted_logits,
predicted_iou,
predicted_lowres_logits,
) = inference_session.run(
output_names=None,
input_feed={
"image_embeddings": image_embeddings,
"batched_point_coords": input_points,
"batched_point_labels": input_labels,
"orig_im_size": np.array(input_image.shape[2:], dtype=np.int64),
},
)
print("decoder time", time.time() - t_start)
mask = predicted_logits[0, 0, 0, :, :] >= 0
imgviz.io.imsave(f"figs/examples/dogs_onnx_mask.png", mask)
def main():
image = np.array(Image.open("figs/examples/dogs.jpg"))
input_image = image.transpose(2, 0, 1)[None].astype(np.float32) / 255.0
# batch_size, num_queries, num_points, 2
input_points = np.array([[[[580, 350], [650, 350]]]], dtype=np.float32)
# batch_size, num_queries, num_points
input_labels = np.array([[[1, 1]]], dtype=np.float32)
predict_onnx(input_image, input_points, input_labels)
if __name__ == "__main__":
main()
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# EfficientSAM
## 论文
`EfficientSAM: Leveraged Masked Image Pretraining for Efficient Segment Anything`
- https://arxiv.org/abs/2312.00863
## 模型结构
EfficientSAM模型利用掩码图像预训练(SAMI),该预训练学习从SAM图像编码器重构特征,以进行有效的视觉表示学习。然后采用SAMI预训练的轻量级图像编码器和掩码解码器来构建EfficientSAMs ,并在SA-1B数据集上对模型进行微调以执行分割一切的任务。EfficientSAM-S将SAM的推理时间减少了约20倍,参数大小减少了约20倍,性能下降很小。
<div align=left>
<img src="./doc/The overview.png"/>
</div>
## 算法原理
模型包含两个阶段:ImageNet上的SAMI预训练和SA-1B上的SAM微调。EfficientSAM的核心组件包括:交叉注意力解码器、线性投影头、重建损失。
交叉注意力解码器:在SAM特征监督下,解码器重构掩蔽令牌,同时编码器输出作为重构锚点。解码器查询来自掩码令牌,键和值来自编码器和未掩码特征。结合编解码器两者输出特征,用于MAE输出嵌入,并重新排序至原始图像位置。
线性投影头:将编码器和解码器输出特征输入到线性投影头,以对齐SAM图像编码器特征并解决特征维数不匹配问题。
重建损失:在每次训练迭代中,SAMI由从SAM图像编码器中提取的前馈特征,以及MAE的前馈和反向传播过程组成。比较了SAM图像编码器和MAE线性投影头的输出,计算了重建损失。
<div align=center>
<img src="./doc/EfficientSAM framework.png"/>
</div>
## 环境配置
### Docker(方法一)
此处提供[光源](https://www.sourcefind.cn/#/service-details)拉取docker镜像的地址与使用步骤
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk23.10-py38
docker run -it --shm-size=64G -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal:/opt/hyhal --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name efficientsam_pytorch <your IMAGE ID> bash # <your IMAGE ID>为以上拉取的docker的镜像ID替换,本镜像为:ffa1f63239fc
cd /path/your_code_data/efficientsam_pytorch
pip install -e .
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
git clone https://mirror.ghproxy.com/https://github.com/facebookresearch/segment-anything.git
cd segment-anything
pip install -e .
```
### Dockerfile(方法二)
此处提供dockerfile的使用方法
```
docker build --no-cache -t efficientsam:latest .
docker run -it --shm-size=64G -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal:/opt/hyhal --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name efficientsam_pytorch efficientsam bash
cd /path/your_code_data/efficientsam_pytorch
pip install -e .
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
git clone https://mirror.ghproxy.com/https://github.com/facebookresearch/segment-anything.git
cd segment-anything
pip install -e .
```
### Anaconda(方法三)
此处提供本地配置、编译的详细步骤,例如:
关于本项目DCU显卡所需的特殊深度学习库可从[光合](https://developer.hpccube.com/tool/)开发者社区下载安装。
```
DTK驱动:dtk23.10
python:python3.8
torch: 2.1.0
torchvision: 0.16.0
triton:2.1.0
```
`Tips:以上dtk驱动、python、torch等DCU相关工具版本需要严格一一对应`
其它依赖环境安装如下:
```
cd /path/your_code_data/efficientsam_pytorch
pip install -e .
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
git clone https://mirror.ghproxy.com/https://github.com/facebookresearch/segment-anything.git
cd segment-anything
pip install -e .
```
## 数据集
预训练阶段数据集为数据集为ImageNet
[ImageNet](https://image-net.org/)
微调阶段数据集为数据集为SA-1B
[SA-1B](https://ai.meta.com/datasets/segment-anything/)
## 训练
官方暂未开放
## 推理
模型的权重可以通过以下表格链接获得,推理时将其下载放置于weights文件夹下。
| EfficientSAM-S | EfficientSAM-Ti |
|------------------------------|------------------------------|
| [Download](https://github.com/yformer/EfficientSAM/blob/main/weights/efficient_sam_vits.pt.zip) |[Download](https://github.com/yformer/EfficientSAM/blob/main/weights/efficient_sam_vitt.pt)|
单卡推理
```
cd /path/your_code_data/efficientsam_pytorch
```
基于point和box推理,更多细节参考netbooks/EfficientSAM_example.ipynb:
基于point推理
```
python inference_point_prompt.py
```
基于box推理
```
python inference_box_prompt.py
```
segment everything推理,更多细节参考netbooks/EfficientSAM_segment_everything_example.ipynb:
```
python inference_segment_everything.py
```
## result
EfficientSAM-S和EfficientSAM-Ti 基于point测试结果如下:
<div align=center>
<img src="/results/efficientsam_point.png"/>
</div>
EfficientSAM-S和EfficientSAM-Ti 基于box测试结果如下:
<div align=center>
<img src="/results/efficientsam_box.png"/>
</div>
EfficientSAM-S和EfficientSAM-Ti segment_everything测试结果如下:
<div align=center>
<img src="/results/segmenteverything.png"/>
</div>
### 精度
## 应用场景
### 算法类别
`图像分割`
### 热点应用行业
`制造,广媒,能源,医疗,家居,教育`
## 源码仓库及问题反馈
- 此处填本项目gitlab地址
## 参考资料
- https://github.com/yformer/EfficientSAM
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
from .build_efficient_sam import (
build_efficient_sam_vitt,
build_efficient_sam_vits,
)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .efficient_sam import build_efficient_sam
def build_efficient_sam_vitt():
return build_efficient_sam(
encoder_patch_embed_dim=192,
encoder_num_heads=3,
checkpoint="weights/efficient_sam_vitt.pt",
).eval()
def build_efficient_sam_vits():
return build_efficient_sam(
encoder_patch_embed_dim=384,
encoder_num_heads=6,
checkpoint="weights/efficient_sam_vits.pt",
).eval()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Any, List, Tuple, Type
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from .efficient_sam_decoder import MaskDecoder, PromptEncoder
from .efficient_sam_encoder import ImageEncoderViT
from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer
class EfficientSam(nn.Module):
mask_threshold: float = 0.0
image_format: str = "RGB"
def __init__(
self,
image_encoder: ImageEncoderViT,
prompt_encoder: PromptEncoder,
decoder_max_num_input_points: int,
mask_decoder: MaskDecoder,
pixel_mean: List[float] = [0.485, 0.456, 0.406],
pixel_std: List[float] = [0.229, 0.224, 0.225],
) -> None:
"""
SAM predicts object masks from an image and input prompts.
Arguments:
image_encoder (ImageEncoderViT): The backbone used to encode the
image into image embeddings that allow for efficient mask prediction.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
and encoded prompts.
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
pixel_std (list(float)): Std values for normalizing pixels in the input image.
"""
super().__init__()
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.decoder_max_num_input_points = decoder_max_num_input_points
self.mask_decoder = mask_decoder
self.register_buffer(
"pixel_mean", torch.Tensor(pixel_mean).view(1, 3, 1, 1), False
)
self.register_buffer(
"pixel_std", torch.Tensor(pixel_std).view(1, 3, 1, 1), False
)
@torch.jit.export
def predict_masks(
self,
image_embeddings: torch.Tensor,
batched_points: torch.Tensor,
batched_point_labels: torch.Tensor,
multimask_output: bool,
input_h: int,
input_w: int,
output_h: int = -1,
output_w: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predicts masks given image embeddings and prompts. This only runs the decoder.
Arguments:
image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
batched_points: A tensor of shape [B, max_num_queries, num_pts, 2]
batched_point_labels: A tensor of shape [B, max_num_queries, num_pts]
Returns:
A tuple of two tensors:
low_res_mask: A tensor of shape [B, max_num_queries, 256, 256] of predicted masks
iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
"""
batch_size, max_num_queries, num_pts, _ = batched_points.shape
num_pts = batched_points.shape[2]
rescaled_batched_points = self.get_rescaled_pts(batched_points, input_h, input_w)
if num_pts > self.decoder_max_num_input_points:
rescaled_batched_points = rescaled_batched_points[
:, :, : self.decoder_max_num_input_points, :
]
batched_point_labels = batched_point_labels[
:, :, : self.decoder_max_num_input_points
]
elif num_pts < self.decoder_max_num_input_points:
rescaled_batched_points = F.pad(
rescaled_batched_points,
(0, 0, 0, self.decoder_max_num_input_points - num_pts),
value=-1.0,
)
batched_point_labels = F.pad(
batched_point_labels,
(0, self.decoder_max_num_input_points - num_pts),
value=-1.0,
)
sparse_embeddings = self.prompt_encoder(
rescaled_batched_points.reshape(
batch_size * max_num_queries, self.decoder_max_num_input_points, 2
),
batched_point_labels.reshape(
batch_size * max_num_queries, self.decoder_max_num_input_points
),
)
sparse_embeddings = sparse_embeddings.view(
batch_size,
max_num_queries,
sparse_embeddings.shape[1],
sparse_embeddings.shape[2],
)
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings,
self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
multimask_output=multimask_output,
)
_, num_predictions, low_res_size, _ = low_res_masks.shape
if output_w > 0 and output_h > 0:
output_masks = F.interpolate(
low_res_masks, (output_h, output_w), mode="bicubic"
)
output_masks = torch.reshape(
output_masks,
(batch_size, max_num_queries, num_predictions, output_h, output_w),
)
else:
output_masks = torch.reshape(
low_res_masks,
(
batch_size,
max_num_queries,
num_predictions,
low_res_size,
low_res_size,
),
)
iou_predictions = torch.reshape(
iou_predictions, (batch_size, max_num_queries, num_predictions)
)
return output_masks, iou_predictions
def get_rescaled_pts(self, batched_points: torch.Tensor, input_h: int, input_w: int):
return torch.stack(
[
torch.where(
batched_points[..., 0] >= 0,
batched_points[..., 0] * self.image_encoder.img_size / input_w,
-1.0,
),
torch.where(
batched_points[..., 1] >= 0,
batched_points[..., 1] * self.image_encoder.img_size / input_h,
-1.0,
),
],
dim=-1,
)
@torch.jit.export
def get_image_embeddings(self, batched_images) -> torch.Tensor:
"""
Predicts masks end-to-end from provided images and prompts.
If prompts are not known in advance, using SamPredictor is
recommended over calling the model directly.
Arguments:
batched_images: A tensor of shape [B, 3, H, W]
Returns:
List of image embeddings each of of shape [B, C(i), H(i), W(i)].
The last embedding corresponds to the final layer.
"""
batched_images = self.preprocess(batched_images)
return self.image_encoder(batched_images)
def forward(
self,
batched_images: torch.Tensor,
batched_points: torch.Tensor,
batched_point_labels: torch.Tensor,
scale_to_original_image_size: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predicts masks end-to-end from provided images and prompts.
If prompts are not known in advance, using SamPredictor is
recommended over calling the model directly.
Arguments:
batched_images: A tensor of shape [B, 3, H, W]
batched_points: A tensor of shape [B, num_queries, max_num_pts, 2]
batched_point_labels: A tensor of shape [B, num_queries, max_num_pts]
Returns:
A list tuples of two tensors where the ith element is by considering the first i+1 points.
low_res_mask: A tensor of shape [B, 256, 256] of predicted masks
iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
"""
batch_size, _, input_h, input_w = batched_images.shape
image_embeddings = self.get_image_embeddings(batched_images)
return self.predict_masks(
image_embeddings,
batched_points,
batched_point_labels,
multimask_output=True,
input_h=input_h,
input_w=input_w,
output_h=input_h if scale_to_original_image_size else -1,
output_w=input_w if scale_to_original_image_size else -1,
)
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""
if (
x.shape[2] != self.image_encoder.img_size
or x.shape[3] != self.image_encoder.img_size
):
x = F.interpolate(
x,
(self.image_encoder.img_size, self.image_encoder.img_size),
mode="bilinear",
)
return (x - self.pixel_mean) / self.pixel_std
def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None):
img_size = 1024
encoder_patch_size = 16
encoder_depth = 12
encoder_mlp_ratio = 4.0
encoder_neck_dims = [256, 256]
decoder_max_num_input_points = 6
decoder_transformer_depth = 2
decoder_transformer_mlp_dim = 2048
decoder_num_heads = 8
decoder_upscaling_layer_dims = [64, 32]
num_multimask_outputs = 3
iou_head_depth = 3
iou_head_hidden_dim = 256
activation = "gelu"
normalization_type = "layer_norm"
normalize_before_activation = False
assert activation == "relu" or activation == "gelu"
if activation == "relu":
activation_fn = nn.ReLU
else:
activation_fn = nn.GELU
image_encoder = ImageEncoderViT(
img_size=img_size,
patch_size=encoder_patch_size,
in_chans=3,
patch_embed_dim=encoder_patch_embed_dim,
normalization_type=normalization_type,
depth=encoder_depth,
num_heads=encoder_num_heads,
mlp_ratio=encoder_mlp_ratio,
neck_dims=encoder_neck_dims,
act_layer=activation_fn,
)
image_embedding_size = image_encoder.image_embedding_size
encoder_transformer_output_dim = image_encoder.transformer_output_dim
sam = EfficientSam(
image_encoder=image_encoder,
prompt_encoder=PromptEncoder(
embed_dim=encoder_transformer_output_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(img_size, img_size),
),
decoder_max_num_input_points=decoder_max_num_input_points,
mask_decoder=MaskDecoder(
transformer_dim=encoder_transformer_output_dim,
transformer=TwoWayTransformer(
depth=decoder_transformer_depth,
embedding_dim=encoder_transformer_output_dim,
num_heads=decoder_num_heads,
mlp_dim=decoder_transformer_mlp_dim,
activation=activation_fn,
normalize_before_activation=normalize_before_activation,
),
num_multimask_outputs=num_multimask_outputs,
activation=activation_fn,
normalization_type=normalization_type,
normalize_before_activation=normalize_before_activation,
iou_head_depth=iou_head_depth - 1,
iou_head_hidden_dim=iou_head_hidden_dim,
upscaling_layer_dims=decoder_upscaling_layer_dims,
),
pixel_mean=[0.485, 0.456, 0.406],
pixel_std=[0.229, 0.224, 0.225],
)
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f, map_location="cpu")
sam.load_state_dict(state_dict["model"])
return sam
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Tuple, Type
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mlp import MLPBlock
class PromptEncoder(nn.Module):
def __init__(
self,
embed_dim: int,
image_embedding_size: Tuple[int, int],
input_image_size: Tuple[int, int],
) -> None:
"""
Encodes prompts for input to SAM's mask decoder.
Arguments:
embed_dim (int): The prompts' embedding dimension
image_embedding_size (tuple(int, int)): The spatial size of the
image embedding, as (H, W).
input_image_size (int): The padded size of the image as input
to the image encoder, as (H, W).
"""
super().__init__()
self.embed_dim = embed_dim
self.input_image_size = input_image_size
self.image_embedding_size = image_embedding_size
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.invalid_points = nn.Embedding(1, embed_dim)
self.point_embeddings = nn.Embedding(1, embed_dim)
self.bbox_top_left_embeddings = nn.Embedding(1, embed_dim)
self.bbox_bottom_right_embeddings = nn.Embedding(1, embed_dim)
def get_dense_pe(self) -> torch.Tensor:
"""
Returns the positional encoding used to encode point prompts,
applied to a dense set of points the shape of the image encoding.
Returns:
torch.Tensor: Positional encoding with shape
1x(embed_dim)x(embedding_h)x(embedding_w)
"""
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
point_embedding = self.pe_layer.forward_with_coords(
points, self.input_image_size
)
invalid_label_ids = torch.eq(labels, -1)[:,:,None]
point_label_ids = torch.eq(labels, 1)[:,:,None]
topleft_label_ids = torch.eq(labels, 2)[:,:,None]
bottomright_label_ids = torch.eq(labels, 3)[:,:,None]
point_embedding = point_embedding + self.invalid_points.weight[:,None,:] * invalid_label_ids
point_embedding = point_embedding + self.point_embeddings.weight[:,None,:] * point_label_ids
point_embedding = point_embedding + self.bbox_top_left_embeddings.weight[:,None,:] * topleft_label_ids
point_embedding = point_embedding + self.bbox_bottom_right_embeddings.weight[:,None,:] * bottomright_label_ids
return point_embedding
def forward(
self,
coords,
labels,
) -> torch.Tensor:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.
Arguments:
points: A tensor of shape [B, 2]
labels: An integer tensor of shape [B] where each element is 1,2 or 3.
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
"""
return self._embed_points(coords, labels)
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
def __init__(self, num_pos_feats: int) -> None:
super().__init__()
self.register_buffer(
"positional_encoding_gaussian_matrix", torch.randn((2, num_pos_feats))
)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device = self.positional_encoding_gaussian_matrix.device
grid = torch.ones([h, w], device=device, dtype=torch.float32)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / h
x_embed = x_embed / w
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
return pe.permute(2, 0, 1) # C x H x W
def forward_with_coords(
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
coords = coords_input.clone()
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
return self._pe_encoding(coords.to(torch.float)) # B x N x C
class MaskDecoder(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int,
activation: Type[nn.Module],
normalization_type: str,
normalize_before_activation: bool,
iou_head_depth: int,
iou_head_hidden_dim: int,
upscaling_layer_dims: List[int],
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a
transformer architecture.
Arguments:
transformer_dim (int): the channel dimension of the transformer
transformer (nn.Module): the transformer used to predict masks
num_multimask_outputs (int): the number of masks to predict
when disambiguating masks
activation (nn.Module): the type of activation to use when
upscaling masks
iou_head_depth (int): the depth of the MLP used to predict
mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP
used to predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
if num_multimask_outputs > 1:
self.num_mask_tokens = num_multimask_outputs + 1
else:
self.num_mask_tokens = 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
output_dim_after_upscaling = transformer_dim
self.final_output_upscaling_layers = nn.ModuleList([])
for idx, layer_dims in enumerate(upscaling_layer_dims):
self.final_output_upscaling_layers.append(
nn.Sequential(
nn.ConvTranspose2d(
output_dim_after_upscaling,
layer_dims,
kernel_size=2,
stride=2,
),
nn.GroupNorm(1, layer_dims)
if idx < len(upscaling_layer_dims) - 1
else nn.Identity(),
activation(),
)
)
output_dim_after_upscaling = layer_dims
self.output_hypernetworks_mlps = nn.ModuleList(
[
MLPBlock(
input_dim=transformer_dim,
hidden_dim=transformer_dim,
output_dim=output_dim_after_upscaling,
num_layers=2,
act=activation,
)
for i in range(self.num_mask_tokens)
]
)
self.iou_prediction_head = MLPBlock(
input_dim=transformer_dim,
hidden_dim=iou_head_hidden_dim,
output_dim=self.num_mask_tokens,
num_layers=iou_head_depth,
act=activation,
)
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings (the batch dimension is broadcastable).
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
(
batch_size,
max_num_queries,
sparse_embed_dim_1,
sparse_embed_dim_2,
) = sparse_prompt_embeddings.shape
(
_,
image_embed_dim_c,
image_embed_dim_h,
image_embed_dim_w,
) = image_embeddings.shape
# Tile the image embedding for all queries.
image_embeddings_tiled = torch.tile(
image_embeddings[:, None, :, :, :], [1, max_num_queries, 1, 1, 1]
).view(
batch_size * max_num_queries,
image_embed_dim_c,
image_embed_dim_h,
image_embed_dim_w,
)
sparse_prompt_embeddings = sparse_prompt_embeddings.reshape(
batch_size * max_num_queries, sparse_embed_dim_1, sparse_embed_dim_2
)
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings_tiled,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
)
if multimask_output and self.num_multimask_outputs > 1:
return masks[:, 1:, :], iou_pred[:, 1:]
else:
return masks[:, :1, :], iou_pred[:, :1]
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
output_tokens = torch.cat(
[self.iou_token.weight, self.mask_tokens.weight], dim=0
)
output_tokens = output_tokens.unsqueeze(0).expand(
sparse_prompt_embeddings.size(0), -1, -1
)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = image_embeddings.shape
hs, src = self.transformer(image_embeddings, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
upscaled_embedding = src.transpose(1, 2).view(b, c, h, w)
for upscaling_layer in self.final_output_upscaling_layers:
upscaled_embedding = upscaling_layer(upscaled_embedding)
hyper_in_list: List[torch.Tensor] = []
for i, output_hypernetworks_mlp in enumerate(self.output_hypernetworks_mlps):
hyper_in_list.append(output_hypernetworks_mlp(mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import List, Optional, Tuple, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
img_size,
patch_size,
in_chans,
embed_dim,
):
super().__init__()
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
bias=True,
)
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads,
qkv_bias,
qk_scale=None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
act_layer=nn.GELU,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim, eps=1e-6)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-6)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
@torch.jit.export
def get_abs_pos(
abs_pos: torch.Tensor, has_cls_token: bool, hw: List[int]
) -> torch.Tensor:
"""
Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
dimension for the original embeddings.
Args:
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
hw (Tuple): size of input image tokens.
Returns:
Absolute positional embeddings after processing with shape (1, H, W, C)
"""
h = hw[0]
w = hw[1]
if has_cls_token:
abs_pos = abs_pos[:, 1:]
xy_num = abs_pos.shape[1]
size = int(math.sqrt(xy_num))
assert size * size == xy_num
if size != h or size != w:
new_abs_pos = F.interpolate(
abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
size=(h, w),
mode="bicubic",
align_corners=False,
)
return new_abs_pos.permute(0, 2, 3, 1)
else:
return abs_pos.reshape(1, h, w, -1)
# Image encoder for efficient SAM.
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
patch_embed_dim: int,
normalization_type: str,
depth: int,
num_heads: int,
mlp_ratio: float,
neck_dims: List[int],
act_layer: Type[nn.Module],
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
patch_embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
act_layer (nn.Module): Activation layer.
"""
super().__init__()
self.img_size = img_size
self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1))
self.transformer_output_dim = ([patch_embed_dim] + neck_dims)[-1]
self.pretrain_use_cls_token = True
pretrain_img_size = 224
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, patch_embed_dim)
# Initialize absolute positional embedding with pretrain image size.
num_patches = (pretrain_img_size // patch_size) * (
pretrain_img_size // patch_size
)
num_positions = num_patches + 1
self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim))
self.blocks = nn.ModuleList()
for i in range(depth):
vit_block = Block(patch_embed_dim, num_heads, mlp_ratio, True)
self.blocks.append(vit_block)
self.neck = nn.Sequential(
nn.Conv2d(
patch_embed_dim,
neck_dims[0],
kernel_size=1,
bias=False,
),
LayerNorm2d(neck_dims[0]),
nn.Conv2d(
neck_dims[0],
neck_dims[0],
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(neck_dims[0]),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert (
x.shape[2] == self.img_size and x.shape[3] == self.img_size
), "input image size must match self.img_size"
x = self.patch_embed(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
x = x + get_abs_pos(
self.pos_embed, self.pretrain_use_cls_token, [x.shape[1], x.shape[2]]
)
num_patches = x.shape[1]
assert x.shape[2] == num_patches
x = x.reshape(x.shape[0], num_patches * num_patches, x.shape[3])
for blk in self.blocks:
x = blk(x)
x = x.reshape(x.shape[0], num_patches, num_patches, x.shape[2])
x = self.neck(x.permute(0, 3, 1, 2))
return x
from typing import Type
from torch import nn
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLPBlock(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
act: Type[nn.Module],
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Sequential(nn.Linear(n, k), act())
for n, k in zip([input_dim] + h, [hidden_dim] * num_layers)
)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
for layer in self.layers:
x = layer(x)
return self.fc(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment