Unverified Commit 98875f52 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

add label for autoactivation (#4021)

parent fb3d4119
......@@ -7,6 +7,7 @@ import torch.nn as nn
from nni.retiarii.serializer import basic_unit
from .api import LayerChoice
from .utils import generate_new_label
from ...utils import version_larger_equal
__all__ = ['AutoActivation']
......@@ -230,18 +231,23 @@ class AutoActivation(nn.Module):
unit_num : int
the number of core units
"""
def __init__(self, unit_num = 1):
def __init__(self, unit_num: int = 1, label: str = None):
super().__init__()
self._label = generate_new_label(label)
self.unaries = nn.ModuleList()
self.binaries = nn.ModuleList()
self.first_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules])
for _ in range(unit_num):
one_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules])
self.first_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules], label = f'{self.label}__unary_0')
for i in range(unit_num):
one_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules], label = f'{self.label}__unary_{i+1}')
self.unaries.append(one_unary)
for _ in range(unit_num):
one_binary = LayerChoice([eval('{}()'.format(binary)) for binary in binary_modules])
for i in range(unit_num):
one_binary = LayerChoice([eval('{}()'.format(binary)) for binary in binary_modules], label = f'{self.label}__binary_{i}')
self.binaries.append(one_binary)
@property
def label(self):
return self._label
def forward(self, x):
out = self.first_unary(x)
for unary, binary in zip(self.unaries, self.binaries):
......
......@@ -289,8 +289,8 @@ describe('Unit test for nnimanager', function () {
})
})
it('test resumeExperiment', async () => {
//it('test resumeExperiment', async () => {
//TODO: add resume experiment unit test
})
//})
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment