Unverified Commit ab22a5a5 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Export ONNX for visualization in MNIST NAS (#4473)

parent 5d197989
......@@ -181,9 +181,21 @@ The complete code of this example can be found :githublink:`here <examples/nas/m
Visualize the Experiment
------------------------
Users can visualize their experiment in the same way as visualizing a normal hyper-parameter tuning experiment. For example, open ``localhost::8081`` in your browser, 8081 is the port that you set in ``exp.run``. Please refer to `here <../Tutorial/WebUI.rst>`__ for details.
Users can visualize their experiment in the same way as visualizing a normal hyper-parameter tuning experiment. For example, open ``localhost:8081`` in your browser, 8081 is the port that you set in ``exp.run``. Please refer to `here <../Tutorial/WebUI.rst>`__ for details.
We support visualizing models with 3rd-party visualization engines (like `Netron <https://netron.app/>`__). This can be used by clicking ``Visualization`` in detail panel for each trial. Note that current visualization is based on `onnx <https://onnx.ai/>`__ , thus visualization is not feasible if the model cannot be exported into onnx. Built-in evaluators (e.g., Classification) will automatically export the model into a file. For your own evaluator, you need to save your file into ``$NNI_OUTPUT_DIR/model.onnx`` to make this work.
We support visualizing models with 3rd-party visualization engines (like `Netron <https://netron.app/>`__). This can be used by clicking ``Visualization`` in detail panel for each trial. Note that current visualization is based on `onnx <https://onnx.ai/>`__ , thus visualization is not feasible if the model cannot be exported into onnx.
Built-in evaluators (e.g., Classification) will automatically export the model into a file. For your own evaluator, you need to save your file into ``$NNI_OUTPUT_DIR/model.onnx`` to make this work. For instance,
.. code-block:: python
def evaluate_model(model_cls):
model = model_cls()
# dump the model into an onnx
if 'NNI_OUTPUT_DIR' in os.environ:
torch.onnx.export(model, (dummy_input, ),
Path(os.environ['NNI_OUTPUT_DIR']) / 'model.onnx')
# the rest are training and evaluation
Export Top Models
-----------------
......
.. 2cbe7334076be1841320c31208c338ff
.. d6ad1b913b292469c647ca68ac158840
快速入门 Retiarii
==============================
......@@ -183,9 +183,21 @@ Retiarii 提供了诸多的 `内置模型评估器 <./ModelEvaluators.rst>`__,
可视化 Experiment
------------------------
用户可以像可视化普通的超参数调优 Experiment 一样可视化他们的 Experiment 例如,在浏览器里打开 ``localhost::8081``8081 是在 ``exp.run`` 里设置的端口。 参考 `这里 <../Tutorial/WebUI.rst>`__ 了解更多细节。
用户可以像可视化普通的超参数调优 Experiment 一样可视化他们的 Experiment 例如,在浏览器里打开 ``localhost:8081``8081 是在 ``exp.run`` 里设置的端口。 参考 `这里 <../Tutorial/WebUI.rst>`__ 了解更多细节。
我们支持使用第三方工具(例如 `Netron <https://netron.app/>`__)可视化搜索过程中采样到的模型。您可以点击每个 trial 面板下的 ``Visualization``。注意,目前的可视化是基于导出成 `onnx <https://onnx.ai/>`__ 格式的模型实现的,所以如果模型无法导出成 onnx,那么可视化就无法进行。内置的模型评估器(比如 Classification)已经自动将模型导出成了一个文件。如果您自定义了模型,您需要将模型导出到 ``$NNI_OUTPUT_DIR/model.onnx``
我们支持使用第三方工具(例如 `Netron <https://netron.app/>`__)可视化搜索过程中采样到的模型。您可以点击每个 trial 面板下的 ``Visualization``。注意,目前的可视化是基于导出成 `onnx <https://onnx.ai/>`__ 格式的模型实现的,所以如果模型无法导出成 onnx,那么可视化就无法进行。
内置的模型评估器(比如 Classification)已经自动将模型导出成了一个文件。如果您自定义了模型,您需要将模型导出到 ``$NNI_OUTPUT_DIR/model.onnx``。例如,
.. code-block:: python
def evaluate_model(model_cls):
model = model_cls()
# 把模型导出成 onnx
if 'NNI_OUTPUT_DIR' in os.environ:
torch.onnx.export(model, (dummy_input, ),
Path(os.environ['NNI_OUTPUT_DIR']) / 'model.onnx')
# 剩下的就是训练和测试流程
导出最佳模型
-----------------
......
import os
import random
from pathlib import Path
import nni
import torch
......@@ -93,6 +95,11 @@ def evaluate_model(model_cls):
# "model_cls" is a class, need to instantiate
model = model_cls()
# export model for visualization
if 'NNI_OUTPUT_DIR' in os.environ:
torch.onnx.export(model, (torch.randn(1, 1, 28, 28), ),
Path(os.environ['NNI_OUTPUT_DIR']) / 'model.onnx')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
transf = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = DataLoader(MNIST('data/mnist', download=True, transform=transf), batch_size=64, shuffle=True)
......
......@@ -248,7 +248,7 @@ class Classification(Lightning):
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloders : DataLoader
train_dataloaders : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
......@@ -301,7 +301,7 @@ class Regression(Lightning):
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloders : DataLoader
train_dataloaders : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment