Unverified Commit 98875f52 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

add label for autoactivation (#4021)

parent fb3d4119
...@@ -7,6 +7,7 @@ import torch.nn as nn ...@@ -7,6 +7,7 @@ import torch.nn as nn
from nni.retiarii.serializer import basic_unit from nni.retiarii.serializer import basic_unit
from .api import LayerChoice from .api import LayerChoice
from .utils import generate_new_label
from ...utils import version_larger_equal from ...utils import version_larger_equal
__all__ = ['AutoActivation'] __all__ = ['AutoActivation']
...@@ -230,18 +231,23 @@ class AutoActivation(nn.Module): ...@@ -230,18 +231,23 @@ class AutoActivation(nn.Module):
unit_num : int unit_num : int
the number of core units the number of core units
""" """
def __init__(self, unit_num = 1): def __init__(self, unit_num: int = 1, label: str = None):
super().__init__() super().__init__()
self._label = generate_new_label(label)
self.unaries = nn.ModuleList() self.unaries = nn.ModuleList()
self.binaries = nn.ModuleList() self.binaries = nn.ModuleList()
self.first_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules]) self.first_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules], label = f'{self.label}__unary_0')
for _ in range(unit_num): for i in range(unit_num):
one_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules]) one_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules], label = f'{self.label}__unary_{i+1}')
self.unaries.append(one_unary) self.unaries.append(one_unary)
for _ in range(unit_num): for i in range(unit_num):
one_binary = LayerChoice([eval('{}()'.format(binary)) for binary in binary_modules]) one_binary = LayerChoice([eval('{}()'.format(binary)) for binary in binary_modules], label = f'{self.label}__binary_{i}')
self.binaries.append(one_binary) self.binaries.append(one_binary)
@property
def label(self):
return self._label
def forward(self, x): def forward(self, x):
out = self.first_unary(x) out = self.first_unary(x)
for unary, binary in zip(self.unaries, self.binaries): for unary, binary in zip(self.unaries, self.binaries):
......
...@@ -289,8 +289,8 @@ describe('Unit test for nnimanager', function () { ...@@ -289,8 +289,8 @@ describe('Unit test for nnimanager', function () {
}) })
}) })
it('test resumeExperiment', async () => { //it('test resumeExperiment', async () => {
//TODO: add resume experiment unit test //TODO: add resume experiment unit test
}) //})
}) })
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment