Unverified Commit ca980ee7 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Add model wrapper to other Retiarii examples (#3838)

parent 71fc4da2
...@@ -109,7 +109,9 @@ ...@@ -109,7 +109,9 @@
"source": [ "source": [
"import torch.nn.functional as F\n", "import torch.nn.functional as F\n",
"import nni.retiarii.nn.pytorch as nn\n", "import nni.retiarii.nn.pytorch as nn\n",
"from nni.retiarii import model_wrapper\n",
"\n", "\n",
"@model_wrapper\n",
"class Net(nn.Module):\n", "class Net(nn.Module):\n",
" def __init__(self):\n", " def __init__(self):\n",
" super(Net, self).__init__()\n", " super(Net, self).__init__()\n",
...@@ -949,4 +951,4 @@ ...@@ -949,4 +951,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 4 "nbformat_minor": 4
} }
\ No newline at end of file
...@@ -127,7 +127,9 @@ ...@@ -127,7 +127,9 @@
"source": [ "source": [
"import nni.retiarii.nn.pytorch as nn\n", "import nni.retiarii.nn.pytorch as nn\n",
"import torch.nn.functional as F\n", "import torch.nn.functional as F\n",
"from nni.retiarii import model_wrapper\n",
"\n", "\n",
"@model_wrapper\n",
"class Net(nn.Module):\n", "class Net(nn.Module):\n",
"\n", "\n",
" def __init__(self, input_size):\n", " def __init__(self, input_size):\n",
...@@ -1069,4 +1071,4 @@ ...@@ -1069,4 +1071,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 4 "nbformat_minor": 4
} }
\ No newline at end of file
from collections import OrderedDict from collections import OrderedDict
from nni.retiarii.serializer import model_wrapper
from typing import (List, Optional) from typing import (List, Optional)
import torch import torch
...@@ -7,7 +8,7 @@ import torch.nn as torch_nn ...@@ -7,7 +8,7 @@ import torch.nn as torch_nn
import ops import ops
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit from nni.retiarii import basic_unit, model_wrapper
@basic_unit @basic_unit
class AuxiliaryHead(nn.Module): class AuxiliaryHead(nn.Module):
...@@ -98,6 +99,7 @@ class Cell(nn.Module): ...@@ -98,6 +99,7 @@ class Cell(nn.Module):
output = torch.cat(new_tensors, dim=1) output = torch.cat(new_tensors, dim=1)
return output return output
@model_wrapper
class CNN(nn.Module): class CNN(nn.Module):
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4, def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment