Unverified Commit bddaf36d authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Tutorial of searching in DARTS space (#5053)

parent cbc6273a
...@@ -5,7 +5,7 @@ Neural Architecture Search ...@@ -5,7 +5,7 @@ Neural Architecture Search
:hidden: :hidden:
overview overview
Quickstart </tutorials/hello_nas> Tutorials <tutorials>
construct_space construct_space
exploration_strategy exploration_strategy
evaluator evaluator
......
NAS Tutorials
=============
.. toctree::
:hidden:
Hello NAS! </tutorials/hello_nas>
Search in DARTS </tutorials/darts>
This diff is collapsed.
This diff is collapsed.
240d9ba3c97be549376aa4ef2bd08344
\ No newline at end of file
This diff is collapsed.
...@@ -146,6 +146,23 @@ Tutorials ...@@ -146,6 +146,23 @@ Tutorials
</div> </div>
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, we demonstrate how to search in the famous model space proposed in `DARTS`_.">
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_darts_thumb.png
:alt: Searching in DARTS search space
:ref:`sphx_glr_tutorials_darts.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Searching in DARTS search space</div>
</div>
.. raw:: html .. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Workable Pruning Process ------------------------"> <div class="sphx-glr-thumbcontainer" tooltip="Workable Pruning Process ------------------------">
...@@ -179,6 +196,7 @@ Tutorials ...@@ -179,6 +196,7 @@ Tutorials
/tutorials/nasbench_as_dataset /tutorials/nasbench_as_dataset
/tutorials/pruning_customize /tutorials/pruning_customize
/tutorials/hello_nas /tutorials/hello_nas
/tutorials/darts
/tutorials/pruning_bert_glue /tutorials/pruning_bert_glue
......
...@@ -5,8 +5,10 @@ ...@@ -5,8 +5,10 @@
Computation times Computation times
================= =================
**00:27.206** total execution time for **tutorials** files: **01:38.004** total execution time for **tutorials** files:
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_darts.py` (``darts.py``) | 01:38.004 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:27.206 | 0.0 MB | | :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:27.206 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
......
data/ data/
log/ log/
*.onnx *.onnx
lightning_logs
models/ models/
pruning_log/ pruning_log/
\ No newline at end of file
This diff is collapsed.
...@@ -260,12 +260,12 @@ class SupervisedLearningModule(LightningModule): ...@@ -260,12 +260,12 @@ class SupervisedLearningModule(LightningModule):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
def on_validation_epoch_end(self): def on_validation_epoch_end(self):
if not self.trainer.sanity_checking and self.running_mode == 'multi': if not self.trainer.sanity_checking and self.running_mode == 'multi' and nni.get_current_parameter() is not None:
# Don't report metric when sanity checking # Don't report metric when sanity checking
nni.report_intermediate_result(self._get_validation_metrics()) nni.report_intermediate_result(self._get_validation_metrics())
def on_fit_end(self): def on_fit_end(self):
if self.running_mode == 'multi': if self.running_mode == 'multi' and nni.get_current_parameter() is not None:
nni.report_final_result(self._get_validation_metrics()) nni.report_final_result(self._get_validation_metrics())
def _get_validation_metrics(self): def _get_validation_metrics(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment