model.py 370 Bytes
Newer Older
1
import nni.retiarii.nn.pytorch as nn
2
from nni.retiarii import basic_unit
3
4


5
@basic_unit
class ImportTest(nn.Module):
    """Small test module holding a ``Linear(foo, 3)`` layer and a ``Dropout(bar)`` layer.

    Instances compare equal when their linear-layer input width and dropout
    probability match. No ``forward`` is defined in this class.

    NOTE(review): ``@basic_unit`` comes from NNI Retiarii; presumably it makes
    the framework treat this module as an opaque unit and record its init
    arguments — confirm against the NNI documentation.
    """

    def __init__(self, foo, bar):
        """Build the two sub-layers.

        Args:
            foo: Input feature count for the ``nn.Linear`` layer (its output
                width is fixed at 3).
            bar: Drop probability passed to ``nn.Dropout``.
        """
        super().__init__()
        self.foo = nn.Linear(foo, 3)
        self.bar = nn.Dropout(bar)

    def __eq__(self, other):
        """Compare by configuration: linear input width and dropout probability.

        Returns ``NotImplemented`` when ``other`` does not expose the expected
        ``foo.in_features`` / ``bar.p`` attributes, so Python falls back to its
        default handling (which yields ``False``) instead of raising
        ``AttributeError``.
        """
        try:
            return (
                self.foo.in_features == other.foo.in_features
                and self.bar.p == other.bar.p
            )
        except AttributeError:
            # `other` is not shaped like an ImportTest; let Python decide.
            return NotImplemented

    # Defining ``__eq__`` implicitly sets ``__hash__`` to ``None``, which would
    # make instances unhashable. torch's ``Module.named_modules`` (behind
    # ``parameters()`` / ``modules()``) tracks visited modules in a ``set``, so
    # keep the default identity hash that other modules use.
    # NOTE(review): assumes ``nn.Module`` here is (a subclass of)
    # ``torch.nn.Module``. Caveat: equal-by-config instances still hash
    # differently, so don't rely on them colliding as dict/set keys.
    __hash__ = object.__hash__