"test/vscode:/vscode.git/clone" did not exist on "ef4459e4d82a090580b4676a4e2d3f0960ff3803"
Unverified Commit 319ff036 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Allow kwargs in layer choice (#2351)

parent 80242f2b
...@@ -104,7 +104,7 @@ class BaseMutator(nn.Module): ...@@ -104,7 +104,7 @@ class BaseMutator(nn.Module):
""" """
pass pass
def on_forward_layer_choice(self, mutable, *inputs): def on_forward_layer_choice(self, mutable, *args, **kwargs):
""" """
Callbacks of forward in LayerChoice. Callbacks of forward in LayerChoice.
...@@ -112,8 +112,10 @@ class BaseMutator(nn.Module): ...@@ -112,8 +112,10 @@ class BaseMutator(nn.Module):
---------- ----------
mutable : LayerChoice mutable : LayerChoice
Module whose forward is called. Module whose forward is called.
inputs : list of torch.Tensor args : list of torch.Tensor
The arguments of its forward function. The arguments of its forward function.
kwargs : dict
The keyword arguments of its forward function.
Returns Returns
------- -------
......
...@@ -58,9 +58,6 @@ class Mutable(nn.Module): ...@@ -58,9 +58,6 @@ class Mutable(nn.Module):
"Or did you apply multiple fixed architectures?") "Or did you apply multiple fixed architectures?")
self.__dict__["mutator"] = mutator self.__dict__["mutator"] = mutator
def forward(self, *inputs):
raise NotImplementedError
@property @property
def key(self): def key(self):
""" """
...@@ -155,14 +152,14 @@ class LayerChoice(Mutable): ...@@ -155,14 +152,14 @@ class LayerChoice(Mutable):
self.reduction = reduction self.reduction = reduction
self.return_mask = return_mask self.return_mask = return_mask
def forward(self, *inputs): def forward(self, *args, **kwargs):
""" """
Returns Returns
------- -------
tuple of tensors tuple of tensors
Output and selection mask. If ``return_mask`` is ``False``, only output is returned. Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
""" """
out, mask = self.mutator.on_forward_layer_choice(self, *inputs) out, mask = self.mutator.on_forward_layer_choice(self, *args, **kwargs)
if self.return_mask: if self.return_mask:
return out, mask return out, mask
return out return out
......
...@@ -128,7 +128,7 @@ class Mutator(BaseMutator): ...@@ -128,7 +128,7 @@ class Mutator(BaseMutator):
result["mutable"][mutable.key].append(path) result["mutable"][mutable.key].append(path)
return result return result
def on_forward_layer_choice(self, mutable, *inputs): def on_forward_layer_choice(self, mutable, *args, **kwargs):
""" """
On default, this method retrieves the decision obtained previously, and select certain operations. On default, this method retrieves the decision obtained previously, and select certain operations.
Only operations with non-zero weight will be executed. The results will be added to a list. Only operations with non-zero weight will be executed. The results will be added to a list.
...@@ -138,7 +138,9 @@ class Mutator(BaseMutator): ...@@ -138,7 +138,9 @@ class Mutator(BaseMutator):
---------- ----------
mutable : LayerChoice mutable : LayerChoice
Layer choice module. Layer choice module.
inputs : list of torch.Tensor args : list of torch.Tensor
Inputs
kwargs : dict
Inputs Inputs
Returns Returns
...@@ -148,16 +150,16 @@ class Mutator(BaseMutator): ...@@ -148,16 +150,16 @@ class Mutator(BaseMutator):
""" """
if self._connect_all: if self._connect_all:
return self._all_connect_tensor_reduction(mutable.reduction, return self._all_connect_tensor_reduction(mutable.reduction,
[op(*inputs) for op in mutable.choices]), \ [op(*args, **kwargs) for op in mutable.choices]), \
torch.ones(mutable.length) torch.ones(mutable.length)
def _map_fn(op, *inputs): def _map_fn(op, args, kwargs):
return op(*inputs) return op(*args, **kwargs)
mask = self._get_decision(mutable) mask = self._get_decision(mutable)
assert len(mask) == len(mutable.choices), \ assert len(mask) == len(mutable.choices), \
"Invalid mask, expected {} to be of length {}.".format(mask, len(mutable.choices)) "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable.choices))
out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask) out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable.choices], mask)
return self._tensor_reduction(mutable.reduction, out), mask return self._tensor_reduction(mutable.reduction, out), mask
def on_forward_input_choice(self, mutable, tensor_list): def on_forward_input_choice(self, mutable, tensor_list):
......
...@@ -317,7 +317,7 @@ class ProxylessNasMutator(BaseMutator): ...@@ -317,7 +317,7 @@ class ProxylessNasMutator(BaseMutator):
self.mutable_list.append(mutable) self.mutable_list.append(mutable)
mutable.registered_module = MixedOp(mutable) mutable.registered_module = MixedOp(mutable)
def on_forward_layer_choice(self, mutable, *inputs): def on_forward_layer_choice(self, mutable, *args, **kwargs):
""" """
Callback of layer choice forward. This function defines the forward Callback of layer choice forward. This function defines the forward
logic of the input mutable. So mutable is only interface, its real logic of the input mutable. So mutable is only interface, its real
...@@ -327,7 +327,9 @@ class ProxylessNasMutator(BaseMutator): ...@@ -327,7 +327,9 @@ class ProxylessNasMutator(BaseMutator):
---------- ----------
mutable: LayerChoice mutable: LayerChoice
forward logic of this input mutable forward logic of this input mutable
inputs: list of torch.Tensor args: list of torch.Tensor
inputs of this mutable
kwargs: dict
inputs of this mutable inputs of this mutable
Returns Returns
...@@ -339,7 +341,7 @@ class ProxylessNasMutator(BaseMutator): ...@@ -339,7 +341,7 @@ class ProxylessNasMutator(BaseMutator):
""" """
# FIXME: return mask, to be consistent with other algorithms # FIXME: return mask, to be consistent with other algorithms
idx = mutable.registered_module.active_op_index idx = mutable.registered_module.active_op_index
return mutable.registered_module(mutable, *inputs), idx return mutable.registered_module(mutable, *args, **kwargs), idx
def reset_binary_gates(self): def reset_binary_gates(self):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment