Unverified Commit 319ff036 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Allow kwargs in layer choice (#2351)

parent 80242f2b
......@@ -104,7 +104,7 @@ class BaseMutator(nn.Module):
"""
pass
def on_forward_layer_choice(self, mutable, *inputs):
def on_forward_layer_choice(self, mutable, *args, **kwargs):
"""
Callbacks of forward in LayerChoice.
......@@ -112,8 +112,10 @@ class BaseMutator(nn.Module):
----------
mutable : LayerChoice
Module whose forward is called.
inputs : list of torch.Tensor
args : list of torch.Tensor
The arguments of its forward function.
kwargs : dict
The keyword arguments of its forward function.
Returns
-------
......
......@@ -58,9 +58,6 @@ class Mutable(nn.Module):
"Or did you apply multiple fixed architectures?")
self.__dict__["mutator"] = mutator
def forward(self, *inputs):
    """
    Abstract forward; concrete mutables override this.

    Raises
    ------
    NotImplementedError
        Always — subclasses must provide the forward logic.
    """
    raise NotImplementedError
@property
def key(self):
"""
......@@ -155,14 +152,14 @@ class LayerChoice(Mutable):
self.reduction = reduction
self.return_mask = return_mask
def forward(self, *args, **kwargs):
    """
    Delegate the forward pass to the attached mutator, which selects and runs
    the chosen candidate operation(s).

    Parameters
    ----------
    args : list of torch.Tensor
        Positional arguments forwarded to the candidate operations.
    kwargs : dict
        Keyword arguments forwarded to the candidate operations.

    Returns
    -------
    tuple of tensors
        Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
    """
    # The mutator implements the actual selection/execution policy.
    out, mask = self.mutator.on_forward_layer_choice(self, *args, **kwargs)
    if self.return_mask:
        return out, mask
    return out
......
......@@ -128,7 +128,7 @@ class Mutator(BaseMutator):
result["mutable"][mutable.key].append(path)
return result
def on_forward_layer_choice(self, mutable, *args, **kwargs):
    """
    On default, this method retrieves the decision obtained previously, and select certain operations.
    Only operations with non-zero weight will be executed. The results will be added to a list.

    Parameters
    ----------
    mutable : LayerChoice
        Layer choice module.
    args : list of torch.Tensor
        Inputs
    kwargs : dict
        Inputs

    Returns
    -------
    tuple of torch.Tensor
        Reduced output of the executed candidate operations, and the selection mask.
    """
    if self._connect_all:
        # Run every candidate and reduce; the mask is all-ones in this mode.
        return self._all_connect_tensor_reduction(mutable.reduction,
                                                  [op(*args, **kwargs) for op in mutable.choices]), \
               torch.ones(mutable.length)

    def _map_fn(op, args, kwargs):
        # Helper unpacking the (op, args, kwargs) candidate tuple for _select_with_mask.
        return op(*args, **kwargs)

    mask = self._get_decision(mutable)
    assert len(mask) == len(mutable.choices), \
        "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable.choices))
    # Candidates are packed as (op, args, kwargs) so kwargs survive the selection helper.
    out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable.choices], mask)
    return self._tensor_reduction(mutable.reduction, out), mask
def on_forward_input_choice(self, mutable, tensor_list):
......
......@@ -317,7 +317,7 @@ class ProxylessNasMutator(BaseMutator):
self.mutable_list.append(mutable)
mutable.registered_module = MixedOp(mutable)
def on_forward_layer_choice(self, mutable, *args, **kwargs):
    """
    Callback of layer choice forward. This function defines the forward
    logic of the input mutable; the registered ``MixedOp`` module performs
    the real computation.

    Parameters
    ----------
    mutable : LayerChoice
        forward logic of this input mutable
    args : list of torch.Tensor
        inputs of this mutable
    kwargs : dict
        inputs of this mutable

    Returns
    -------
    tuple
        Output of the registered module and the index of the active op.
    """
    # FIXME: return mask, to be consistent with other algorithms
    idx = mutable.registered_module.active_op_index
    return mutable.registered_module(mutable, *args, **kwargs), idx
def reset_binary_gates(self):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment