Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
ca980ee7
Unverified
Commit
ca980ee7
authored
Jun 18, 2021
by
Yuge Zhang
Committed by
GitHub
Jun 18, 2021
Browse files
Add model wrapper to other Retiarii examples (#3838)
parent
71fc4da2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
3 deletions
+9
-3
examples/notebooks/Retiarii_example_multi-trial_NAS.ipynb
examples/notebooks/Retiarii_example_multi-trial_NAS.ipynb
+3
-1
examples/notebooks/tabular_data_classification_in_AML.ipynb
examples/notebooks/tabular_data_classification_in_AML.ipynb
+3
-1
test/retiarii_test/darts/darts_model.py
test/retiarii_test/darts/darts_model.py
+3
-1
No files found.
examples/notebooks/Retiarii_example_multi-trial_NAS.ipynb
View file @
ca980ee7
...
@@ -109,7 +109,9 @@
...
@@ -109,7 +109,9 @@
"source": [
"source": [
"import torch.nn.functional as F\n",
"import torch.nn.functional as F\n",
"import nni.retiarii.nn.pytorch as nn\n",
"import nni.retiarii.nn.pytorch as nn\n",
"from nni.retiarii import model_wrapper\n",
"\n",
"\n",
"@model_wrapper\n",
"class Net(nn.Module):\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" super(Net, self).__init__()\n",
...
@@ -949,4 +951,4 @@
...
@@ -949,4 +951,4 @@
},
},
"nbformat": 4,
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 4
}
}
\ No newline at end of file
examples/notebooks/tabular_data_classification_in_AML.ipynb
View file @
ca980ee7
...
@@ -127,7 +127,9 @@
...
@@ -127,7 +127,9 @@
"source": [
"source": [
"import nni.retiarii.nn.pytorch as nn\n",
"import nni.retiarii.nn.pytorch as nn\n",
"import torch.nn.functional as F\n",
"import torch.nn.functional as F\n",
"from nni.retiarii import model_wrapper\n",
"\n",
"\n",
"@model_wrapper\n",
"class Net(nn.Module):\n",
"class Net(nn.Module):\n",
"\n",
"\n",
" def __init__(self, input_size):\n",
" def __init__(self, input_size):\n",
...
@@ -1069,4 +1071,4 @@
...
@@ -1069,4 +1071,4 @@
},
},
"nbformat": 4,
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 4
}
}
\ No newline at end of file
test/retiarii_test/darts/darts_model.py
View file @
ca980ee7
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
nni.retiarii.serializer
import
model_wrapper
from
typing
import
(
List
,
Optional
)
from
typing
import
(
List
,
Optional
)
import
torch
import
torch
...
@@ -7,7 +8,7 @@ import torch.nn as torch_nn
...
@@ -7,7 +8,7 @@ import torch.nn as torch_nn
import
ops
import
ops
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
basic_unit
from
nni.retiarii
import
basic_unit
,
model_wrapper
@
basic_unit
@
basic_unit
class
AuxiliaryHead
(
nn
.
Module
):
class
AuxiliaryHead
(
nn
.
Module
):
...
@@ -98,6 +99,7 @@ class Cell(nn.Module):
...
@@ -98,6 +99,7 @@ class Cell(nn.Module):
output
=
torch
.
cat
(
new_tensors
,
dim
=
1
)
output
=
torch
.
cat
(
new_tensors
,
dim
=
1
)
return
output
return
output
@
model_wrapper
class
CNN
(
nn
.
Module
):
class
CNN
(
nn
.
Module
):
def
__init__
(
self
,
input_size
,
in_channels
,
channels
,
n_classes
,
n_layers
,
n_nodes
=
4
,
def
__init__
(
self
,
input_size
,
in_channels
,
channels
,
n_classes
,
n_layers
,
n_nodes
=
4
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment