Unverified Commit ca980ee7 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Add model wrapper to other Retiarii examples (#3838)

parent 71fc4da2
......@@ -109,7 +109,9 @@
"source": [
"import torch.nn.functional as F\n",
"import nni.retiarii.nn.pytorch as nn\n",
"from nni.retiarii import model_wrapper\n",
"\n",
"@model_wrapper\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
......
......@@ -127,7 +127,9 @@
"source": [
"import nni.retiarii.nn.pytorch as nn\n",
"import torch.nn.functional as F\n",
"from nni.retiarii import model_wrapper\n",
"\n",
"@model_wrapper\n",
"class Net(nn.Module):\n",
"\n",
" def __init__(self, input_size):\n",
......
from collections import OrderedDict
from nni.retiarii.serializer import model_wrapper
from typing import (List, Optional)
import torch
......@@ -7,7 +8,7 @@ import torch.nn as torch_nn
import ops
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
from nni.retiarii import basic_unit, model_wrapper
@basic_unit
class AuxiliaryHead(nn.Module):
......@@ -98,6 +99,7 @@ class Cell(nn.Module):
output = torch.cat(new_tensors, dim=1)
return output
@model_wrapper
class CNN(nn.Module):
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment