Unverified Commit 18962129 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Add license header and typehints for NAS (#4774)

parent 8c2f717d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy import copy
import warnings import warnings
from typing import Callable, Dict, List, Union, Optional, Tuple from typing import Callable, Dict, List, Union, Optional, Tuple, Sequence, cast
try: try:
from typing import Literal from typing import Literal
except ImportError: except ImportError:
...@@ -193,8 +196,10 @@ class Cell(nn.Module): ...@@ -193,8 +196,10 @@ class Cell(nn.Module):
def __init__(self, def __init__(self,
op_candidates: Union[ op_candidates: Union[
Callable[[], List[nn.Module]], Callable[[], List[nn.Module]],
List[Union[nn.Module, _cell_op_factory_type]], List[nn.Module],
Dict[str, Union[nn.Module, _cell_op_factory_type]] List[_cell_op_factory_type],
Dict[str, nn.Module],
Dict[str, _cell_op_factory_type]
], ],
num_nodes: int, num_nodes: int,
num_ops_per_node: int = 1, num_ops_per_node: int = 1,
...@@ -251,8 +256,8 @@ class Cell(nn.Module): ...@@ -251,8 +256,8 @@ class Cell(nn.Module):
ops = self._convert_op_candidates(op_candidates, i, k, chosen) ops = self._convert_op_candidates(op_candidates, i, k, chosen)
# though it's layer choice and input choice here, in fixed mode, the chosen module will be created. # though it's layer choice and input choice here, in fixed mode, the chosen module will be created.
self.ops[-1].append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}')) cast(ModuleList, self.ops[-1]).append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
self.inputs[-1].append(inp) cast(ModuleList, self.inputs[-1]).append(inp)
@property @property
def label(self): def label(self):
...@@ -274,13 +279,17 @@ class Cell(nn.Module): ...@@ -274,13 +279,17 @@ class Cell(nn.Module):
By default, it's the output of ``merge_op``, which is a concatenation (on ``concat_dim``) By default, it's the output of ``merge_op``, which is a concatenation (on ``concat_dim``)
of some of (possibly all) the nodes' outputs in the cell. of some of (possibly all) the nodes' outputs in the cell.
""" """
processed_inputs: List[torch.Tensor]
if len(inputs) == 1 and isinstance(inputs[0], list): if len(inputs) == 1 and isinstance(inputs[0], list):
inputs = inputs[0] processed_inputs = list(inputs[0]) # shallow copy
else: else:
inputs = list(inputs) processed_inputs = cast(List[torch.Tensor], list(inputs))
assert len(inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.' assert len(processed_inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
states = self.preprocessor(inputs) states: List[torch.Tensor] = self.preprocessor(processed_inputs)
for ops, inps in zip(self.ops, self.inputs): for ops, inps in zip(
cast(Sequence[Sequence[LayerChoice]], self.ops),
cast(Sequence[Sequence[InputChoice]], self.inputs)
):
current_state = [] current_state = []
for op, inp in zip(ops, inps): for op, inp in zip(ops, inps):
current_state.append(op(inp(states))) current_state.append(op(inp(states)))
...@@ -291,7 +300,7 @@ class Cell(nn.Module): ...@@ -291,7 +300,7 @@ class Cell(nn.Module):
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim) this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
else: else:
this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim) this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim)
return self.postprocessor(this_cell, inputs) return self.postprocessor(this_cell, processed_inputs)
@staticmethod @staticmethod
def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]: def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy import copy
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from typing import Callable, List, Union, Tuple, Optional from typing import Callable, List, Dict, Union, Tuple, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
from .api import LayerChoice, ValueChoice, ValueChoiceX from .api import LayerChoice, ValueChoice, ValueChoiceX, ChoiceOf
from .cell import Cell from .cell import Cell
from .nasbench101 import NasBench101Cell, NasBench101Mutator from .nasbench101 import NasBench101Cell, NasBench101Mutator
from .mutation_utils import Mutable, generate_new_label, get_fixed_value from .mutation_utils import Mutable, generate_new_label, get_fixed_value
...@@ -64,7 +67,7 @@ class Repeat(Mutable): ...@@ -64,7 +67,7 @@ class Repeat(Mutable):
List[Callable[[int], nn.Module]], List[Callable[[int], nn.Module]],
nn.Module, nn.Module,
List[nn.Module]], List[nn.Module]],
depth: Union[int, Tuple[int, int], ValueChoice], *, label: Optional[str] = None): depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
if isinstance(depth, tuple): if isinstance(depth, tuple):
# we can't create a value choice here, # we can't create a value choice here,
# otherwise we will have two value choices, one created here, another in init. # otherwise we will have two value choices, one created here, another in init.
...@@ -90,7 +93,7 @@ class Repeat(Mutable): ...@@ -90,7 +93,7 @@ class Repeat(Mutable):
List[Callable[[int], nn.Module]], List[Callable[[int], nn.Module]],
nn.Module, nn.Module,
List[nn.Module]], List[nn.Module]],
depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None): depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
super().__init__() super().__init__()
self._label = None # by default, no label self._label = None # by default, no label
...@@ -192,7 +195,7 @@ class NasBench201Cell(nn.Module): ...@@ -192,7 +195,7 @@ class NasBench201Cell(nn.Module):
return OrderedDict([(str(i), t) for i, t in enumerate(x)]) return OrderedDict([(str(i), t) for i, t in enumerate(x)])
return OrderedDict(x) return OrderedDict(x)
def __init__(self, op_candidates: List[Callable[[int, int], nn.Module]], def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
in_features: int, out_features: int, num_tensors: int = 4, in_features: int, out_features: int, num_tensors: int = 4,
label: Optional[str] = None): label: Optional[str] = None):
super().__init__() super().__init__()
...@@ -214,16 +217,15 @@ class NasBench201Cell(nn.Module): ...@@ -214,16 +217,15 @@ class NasBench201Cell(nn.Module):
node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}')) # put __ here to be compatible with base engine node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}')) # put __ here to be compatible with base engine
self.layers.append(node_ops) self.layers.append(node_ops)
def forward(self, inputs): def forward(self, inputs: torch.Tensor) -> torch.Tensor:
""" """
The forward of input choice is simply selecting first on all choices. The forward of input choice is simply selecting first on all choices.
It shouldn't be called directly by users in most cases. It shouldn't be called directly by users in most cases.
""" """
tensors = [inputs] tensors: List[torch.Tensor] = [inputs]
for layer in self.layers: for layer in self.layers:
current_tensor = [] current_tensor: List[torch.Tensor] = []
for i, op in enumerate(layer): for i, op in enumerate(layer): # type: ignore
current_tensor.append(op(tensors[i])) current_tensor.append(op(tensors[i])) # type: ignore
current_tensor = torch.sum(torch.stack(current_tensor), 0) tensors.append(torch.sum(torch.stack(current_tensor), 0))
tensors.append(current_tensor)
return tensors[-1] return tensors[-1]
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
from packaging.version import Version from packaging.version import Version
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -233,7 +235,7 @@ class AutoActivation(nn.Module): ...@@ -233,7 +235,7 @@ class AutoActivation(nn.Module):
----- -----
Current `beta` is not per-channel parameter. Current `beta` is not per-channel parameter.
""" """
def __init__(self, unit_num: int = 1, label: str = None): def __init__(self, unit_num: int = 1, label: str | None = None):
super().__init__() super().__init__()
self._label = generate_new_label(label) self._label = generate_new_label(label)
self.unaries = nn.ModuleList() self.unaries = nn.ModuleList()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Optional, Tuple, Union from typing import Any, Optional, Tuple, Union
import torch.nn as nn import torch.nn as nn
...@@ -41,7 +44,7 @@ def generate_new_label(label: Optional[str]): ...@@ -41,7 +44,7 @@ def generate_new_label(label: Optional[str]):
return label return label
def get_fixed_value(label: str) -> Any: def get_fixed_value(label: Optional[str]) -> Any:
ret = get_current_context('fixed') ret = get_current_context('fixed')
try: try:
return ret[generate_new_label(label)] return ret[generate_new_label(label)]
...@@ -49,7 +52,7 @@ def get_fixed_value(label: str) -> Any: ...@@ -49,7 +52,7 @@ def get_fixed_value(label: str) -> Any:
raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}') raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')
def get_fixed_dict(label_prefix: str) -> Tuple[str, Any]: def get_fixed_dict(label_prefix: Optional[str]) -> Tuple[str, Any]:
ret = get_current_context('fixed') ret = get_current_context('fixed')
try: try:
label_prefix = generate_new_label(label_prefix) label_prefix = generate_new_label(label_prefix)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import inspect import inspect
from typing import Any, List, Optional, Tuple, Dict, Iterator from typing import Any, List, Optional, Tuple, Dict, Iterator, Iterable, cast
import torch.nn as nn import torch.nn as nn
...@@ -28,12 +28,14 @@ class LayerChoiceMutator(Mutator): ...@@ -28,12 +28,14 @@ class LayerChoiceMutator(Mutator):
# Each layer choice corresponds to a cell, which is unconnected in the base graph. # Each layer choice corresponds to a cell, which is unconnected in the base graph.
# We add the connections here in the mutation logic. # We add the connections here in the mutation logic.
# Thus, the mutated model should not be mutated again. Everything should be based on the original base graph. # Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
target = model.graphs[node.operation.cell_name] target = model.graphs[cast(Cell, node.operation).cell_name]
chosen_node = target.get_node_by_name(chosen) chosen_node = target.get_node_by_name(chosen)
assert chosen_node is not None assert chosen_node is not None
target.add_edge((target.input_node, 0), (chosen_node, None)) target.add_edge((target.input_node, 0), (chosen_node, None))
target.add_edge((chosen_node, None), (target.output_node, None)) target.add_edge((chosen_node, None), (target.output_node, None))
model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name)) operation = cast(Cell, node.operation)
target_node = cast(Node, model.get_node_by_name(node.name))
target_node.update_operation(Cell(operation.cell_name))
# remove redundant nodes # remove redundant nodes
for rm_node in list(target.hidden_nodes): # remove from a list on the fly will cause issues for rm_node in list(target.hidden_nodes): # remove from a list on the fly will cause issues
...@@ -57,7 +59,7 @@ class InputChoiceMutator(Mutator): ...@@ -57,7 +59,7 @@ class InputChoiceMutator(Mutator):
else: else:
chosen = [self.choice(candidates) for _ in range(n_chosen)] chosen = [self.choice(candidates) for _ in range(n_chosen)]
for node in self.nodes: for node in self.nodes:
target = model.get_node_by_name(node.name) target = cast(Node, model.get_node_by_name(node.name))
target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs', target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs',
{'chosen': chosen, 'reduction': node.operation.parameters['reduction']}) {'chosen': chosen, 'reduction': node.operation.parameters['reduction']})
...@@ -74,7 +76,7 @@ class ValueChoiceMutator(Mutator): ...@@ -74,7 +76,7 @@ class ValueChoiceMutator(Mutator):
# no need to support transformation here, # no need to support transformation here,
# because it is naturally done in forward loop # because it is naturally done in forward loop
for node in self.nodes: for node in self.nodes:
target = model.get_node_by_name(node.name) target = cast(Node, model.get_node_by_name(node.name))
target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen}) target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})
...@@ -86,7 +88,7 @@ class ParameterChoiceLeafMutator(Mutator): ...@@ -86,7 +88,7 @@ class ParameterChoiceLeafMutator(Mutator):
super().__init__(label=label) super().__init__(label=label)
self.candidates = candidates self.candidates = candidates
def mutate(self, model: Model) -> Model: def mutate(self, model: Model) -> None:
# leave a record here # leave a record here
# real mutations will be done in ParameterChoiceMutator # real mutations will be done in ParameterChoiceMutator
self.choice(self.candidates) self.choice(self.candidates)
...@@ -103,7 +105,7 @@ class ParameterChoiceMutator(Mutator): ...@@ -103,7 +105,7 @@ class ParameterChoiceMutator(Mutator):
self.nodes = nodes self.nodes = nodes
def mutate(self, model: Model) -> Model: def mutate(self, model: Model) -> None:
# looks like {"label1": "cat", "label2": 123} # looks like {"label1": "cat", "label2": 123}
value_choice_decisions = {} value_choice_decisions = {}
for mutation in model.history: for mutation in model.history:
...@@ -122,7 +124,7 @@ class ParameterChoiceMutator(Mutator): ...@@ -122,7 +124,7 @@ class ParameterChoiceMutator(Mutator):
result_value = value_choice.evaluate(leaf_node_values) result_value = value_choice.evaluate(leaf_node_values)
# update model with graph mutation primitives # update model with graph mutation primitives
target = model.get_node_by_name(node.name) target = cast(Node, model.get_node_by_name(node.name))
target.update_operation(target.operation.type, {**target.operation.parameters, argname: result_value}) target.update_operation(target.operation.type, {**target.operation.parameters, argname: result_value})
...@@ -138,20 +140,20 @@ class RepeatMutator(Mutator): ...@@ -138,20 +140,20 @@ class RepeatMutator(Mutator):
while u != graph.output_node: while u != graph.output_node:
if u != graph.input_node: if u != graph.input_node:
chain.append(u) chain.append(u)
assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successor}.' assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successors}.'
u = u.successors[0] u = u.successors[0]
return chain return chain
def mutate(self, model): def mutate(self, model):
for node in self.nodes: for node in self.nodes:
# the logic here is similar to layer choice. We find cell attached to each node. # the logic here is similar to layer choice. We find cell attached to each node.
target: Graph = model.graphs[node.operation.cell_name] target: Graph = model.graphs[cast(Cell, node.operation).cell_name]
chain = self._retrieve_chain_from_graph(target) chain = self._retrieve_chain_from_graph(target)
# and we get the chosen depth (by value choice) # and we get the chosen depth (by value choice)
node_in_model = model.get_node_by_name(node.name) node_in_model = cast(Node, model.get_node_by_name(node.name))
# depth is a value choice in base model # depth is a value choice in base model
# but it's already mutated by a ParameterChoiceMutator here # but it's already mutated by a ParameterChoiceMutator here
chosen_depth = node_in_model.operation.parameters['depth'] chosen_depth: int = node_in_model.operation.parameters['depth']
for edge in chain[chosen_depth - 1].outgoing_edges: for edge in chain[chosen_depth - 1].outgoing_edges:
edge.remove() edge.remove()
target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None)) target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
...@@ -159,8 +161,11 @@ class RepeatMutator(Mutator): ...@@ -159,8 +161,11 @@ class RepeatMutator(Mutator):
for edge in rm_node.outgoing_edges: for edge in rm_node.outgoing_edges:
edge.remove() edge.remove()
rm_node.remove() rm_node.remove()
# to delete the unused parameters. # to delete the unused parameters.
model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name)) target_node = cast(Node, model.get_node_by_name(node.name))
cell_operation = cast(Cell, node.operation)
target_node.update_operation(Cell(cell_operation.cell_name))
def process_inline_mutation(model: Model) -> Optional[List[Mutator]]: def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
...@@ -241,7 +246,7 @@ class ManyChooseManyMutator(Mutator): ...@@ -241,7 +246,7 @@ class ManyChooseManyMutator(Mutator):
Choose based on labels. Will not affect the model itself. Choose based on labels. Will not affect the model itself.
""" """
def __init__(self, label: Optional[str]): def __init__(self, label: str):
super().__init__(label=label) super().__init__(label=label)
@staticmethod @staticmethod
...@@ -257,7 +262,7 @@ class ManyChooseManyMutator(Mutator): ...@@ -257,7 +262,7 @@ class ManyChooseManyMutator(Mutator):
return node.operation.parameters['n_chosen'] return node.operation.parameters['n_chosen']
return 1 return 1
def mutate(self, model: Model): def mutate(self, model: Model) -> None:
# this mutate does not have any effect, but it is recorded in the mutation history # this mutate does not have any effect, but it is recorded in the mutation history
for node in model.get_nodes_by_label(self.label): for node in model.get_nodes_by_label(self.label):
n_chosen = self.number_of_chosen(node) n_chosen = self.number_of_chosen(node)
...@@ -280,12 +285,12 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op ...@@ -280,12 +285,12 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
if not is_model_wrapped(pytorch_model): if not is_model_wrapped(pytorch_model):
raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode ' raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode '
'if your model has init parameters.') 'if your model has init parameters.')
model.python_init_params = pytorch_model.trace_kwargs model.python_init_params = cast(dict, pytorch_model.trace_kwargs)
else: else:
model.python_init_params = {} model.python_init_params = {}
# hyper-parameter choice # hyper-parameter choice
namespace: ModelNamespace = pytorch_model._model_namespace namespace: ModelNamespace = cast(ModelNamespace, pytorch_model._model_namespace)
for param_spec in namespace.parameter_specs: for param_spec in namespace.parameter_specs:
assert param_spec.categorical and param_spec.type == 'choice' assert param_spec.categorical and param_spec.type == 'choice'
node = graph.add_node(f'param_spec_{param_spec.name}', 'ModelParameterChoice', {'candidates': param_spec.values}) node = graph.add_node(f'param_spec_{param_spec.name}', 'ModelParameterChoice', {'candidates': param_spec.values})
...@@ -294,7 +299,8 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op ...@@ -294,7 +299,8 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
for name, module in pytorch_model.named_modules(): for name, module in pytorch_model.named_modules():
# tricky case: value choice that serves as parameters are stored in traced arguments # tricky case: value choice that serves as parameters are stored in traced arguments
if is_basic_unit(module): if is_basic_unit(module):
for key, value in module.trace_kwargs.items(): trace_kwargs = cast(Dict[str, Any], module.trace_kwargs)
for key, value in trace_kwargs.items():
if isinstance(value, ValueChoiceX): if isinstance(value, ValueChoiceX):
for i, choice in enumerate(value.inner_choices()): for i, choice in enumerate(value.inner_choices()):
node = graph.add_node(f'{name}.init.{key}.{i}', 'ValueChoice', {'candidates': choice.candidates}) node = graph.add_node(f'{name}.init.{key}.{i}', 'ValueChoice', {'candidates': choice.candidates})
...@@ -329,14 +335,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op ...@@ -329,14 +335,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
mutators = [] mutators = []
mutators_final = [] mutators_final = []
for nodes in _group_by_label_and_type(graph.hidden_nodes): for nodes in _group_by_label_and_type(graph.hidden_nodes):
label = nodes[0].label
assert label is not None, f'label of {nodes[0]} can not be None.'
assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \ assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
f'Node with label "{nodes[0].label}" does not all have the same type.' f'Node with label "{label}" does not all have the same type.'
assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \ assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
f'Node with label "{nodes[0].label}" does not agree on parameters.' f'Node with label "{label}" does not agree on parameters.'
if nodes[0].operation.type == 'NasBench101Cell': if nodes[0].operation.type == 'NasBench101Cell':
mutators_final.append(NasBench101Mutator(nodes[0].label)) # The mutation of Nas-bench-101 is special, and has to be done lastly.
mutators_final.append(NasBench101Mutator(label))
else: else:
mutators.append(ManyChooseManyMutator(nodes[0].label)) mutators.append(ManyChooseManyMutator(label))
return model, mutators + mutators_final return model, mutators + mutators_final
...@@ -350,7 +359,7 @@ class EvaluatorValueChoiceLeafMutator(Mutator): ...@@ -350,7 +359,7 @@ class EvaluatorValueChoiceLeafMutator(Mutator):
super().__init__(label=label) super().__init__(label=label)
self.candidates = candidates self.candidates = candidates
def mutate(self, model: Model) -> Model: def mutate(self, model: Model) -> None:
# leave a record here # leave a record here
# real mutations will be done in ParameterChoiceMutator # real mutations will be done in ParameterChoiceMutator
self.choice(self.candidates) self.choice(self.candidates)
...@@ -388,7 +397,7 @@ class EvaluatorValueChoiceMutator(Mutator): ...@@ -388,7 +397,7 @@ class EvaluatorValueChoiceMutator(Mutator):
return obj return obj
def mutate(self, model: Model): def mutate(self, model: Model) -> None:
value_choice_decisions = {} value_choice_decisions = {}
for mutation in model.history: for mutation in model.history:
if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator): if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator):
...@@ -454,7 +463,7 @@ def _is_all_equal(lst): ...@@ -454,7 +463,7 @@ def _is_all_equal(lst):
return True return True
def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]: def _group_by_label_and_type(nodes: Iterable[Node]) -> List[List[Node]]:
result = {} result = {}
for node in nodes: for node in nodes:
key = (node.label, node.operation.type) key = (node.label, node.operation.type)
...@@ -464,7 +473,7 @@ def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]: ...@@ -464,7 +473,7 @@ def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
return list(result.values()) return list(result.values())
def _group_by_label(nodes: List[Node]) -> List[List[Node]]: def _group_by_label(nodes: Iterable[Node]) -> List[List[Node]]:
result = {} result = {}
for node in nodes: for node in nodes:
label = node.operation.parameters['label'] label = node.operation.parameters['label']
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Callable, List, Optional, Union, Dict from typing import Callable, List, Optional, Union, Dict, Tuple, cast
import numpy as np import numpy as np
import torch import torch
...@@ -89,7 +92,7 @@ def compute_vertex_channels(input_channels, output_channels, matrix): ...@@ -89,7 +92,7 @@ def compute_vertex_channels(input_channels, output_channels, matrix):
return vertex_channels return vertex_channels
def prune(matrix, ops): def prune(matrix, ops) -> Tuple[np.ndarray, List[Union[str, Callable[[int], nn.Module]]]]:
""" """
Prune the extraneous parts of the graph. Prune the extraneous parts of the graph.
...@@ -152,11 +155,17 @@ class _NasBench101CellFixed(nn.Module): ...@@ -152,11 +155,17 @@ class _NasBench101CellFixed(nn.Module):
assert num_nodes == len(operations) + 2 == len(adjacency_list) + 1 assert num_nodes == len(operations) + 2 == len(adjacency_list) + 1
self.operations = ['IN'] + operations + ['OUT'] # add pseudo nodes raw_operations: List[Union[str, Callable[[int], nn.Module]]] = list(operations)
del operations # operations is no longer needed. Delete it to avoid misuse
# add pseudo nodes
raw_operations.insert(0, 'IN')
raw_operations.append('OUT')
self.connection_matrix = self.build_connection_matrix(adjacency_list, num_nodes) self.connection_matrix = self.build_connection_matrix(adjacency_list, num_nodes)
del num_nodes # raw number of nodes is no longer used del num_nodes # raw number of nodes is no longer used
self.connection_matrix, self.operations = prune(self.connection_matrix, self.operations) self.connection_matrix, self.operations = prune(self.connection_matrix, raw_operations)
self.hidden_features = compute_vertex_channels(in_features, out_features, self.connection_matrix) self.hidden_features = compute_vertex_channels(in_features, out_features, self.connection_matrix)
...@@ -172,7 +181,8 @@ class _NasBench101CellFixed(nn.Module): ...@@ -172,7 +181,8 @@ class _NasBench101CellFixed(nn.Module):
self.projections.append(projection(in_features, self.hidden_features[i])) self.projections.append(projection(in_features, self.hidden_features[i]))
for i in range(1, self.num_nodes - 1): for i in range(1, self.num_nodes - 1):
self.ops.append(operations[i - 1](self.hidden_features[i])) operation = cast(Callable[[int], nn.Module], self.operations[i])
self.ops.append(operation(self.hidden_features[i]))
@staticmethod @staticmethod
def build_connection_matrix(adjacency_list, num_nodes): def build_connection_matrix(adjacency_list, num_nodes):
...@@ -361,7 +371,7 @@ class NasBench101Mutator(Mutator): ...@@ -361,7 +371,7 @@ class NasBench101Mutator(Mutator):
# for validation purposes # for validation purposes
# for python execution engine # for python execution engine
def __init__(self, label: Optional[str]): def __init__(self, label: str):
super().__init__(label=label) super().__init__(label=label)
@staticmethod @staticmethod
...@@ -378,9 +388,11 @@ class NasBench101Mutator(Mutator): ...@@ -378,9 +388,11 @@ class NasBench101Mutator(Mutator):
return 1 return 1
def mutate(self, model: Model): def mutate(self, model: Model):
max_num_edges = cast(int, None)
for node in model.get_nodes_by_label(self.label): for node in model.get_nodes_by_label(self.label):
max_num_edges = node.operation.parameters['max_num_edges'] max_num_edges = node.operation.parameters['max_num_edges']
break break
assert max_num_edges is not None
mutation_dict = {mut.mutator.label: mut.samples for mut in model.history} mutation_dict = {mut.mutator.label: mut.samples for mut in model.history}
num_nodes = mutation_dict[f'{self.label}/num_nodes'][0] num_nodes = mutation_dict[f'{self.label}/num_nodes'][0]
adjacency_list = [mutation_dict[f'{self.label}/input{i}'] for i in range(1, num_nodes)] adjacency_list = [mutation_dict[f'{self.label}/input{i}'] for i in range(1, num_nodes)]
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import inspect
import warnings
from pathlib import Path from pathlib import Path
import torch
import torch.nn as nn
# To make auto-completion happy, we generate a _nn.py that lists out all the classes. # To make auto-completion happy, we generate a _nn.py that lists out all the classes.
nn_cache_file_path = Path(__file__).parent / '_nn.py' nn_cache_file_path = Path(__file__).parent / '_nn.py'
cache_valid = False # Update this when cache format changes, to enforce an update.
cache_version = 2
def validate_cache() -> bool:
import torch
cache_valid = []
if nn_cache_file_path.exists():
lines = nn_cache_file_path.read_text().splitlines()
for line in lines:
if line.startswith('# _torch_version'):
_cached_torch_version = line[line.find('=') + 1:].strip()
if _cached_torch_version == torch.__version__:
cache_valid.append(True)
if line.startswith('# _torch_nn_cache_version'):
_cached_cache_version = int(line[line.find('=') + 1:].strip())
if _cached_cache_version == cache_version:
cache_valid.append(True)
return len(cache_valid) >= 2 and all(cache_valid)
if nn_cache_file_path.exists():
from . import _nn # pylint: disable=no-name-in-module
# valid only when torch version match
if _nn._torch_version == torch.__version__:
cache_valid = True
if not cache_valid: def generate_stub_file() -> str:
import inspect
import warnings
import torch
import torch.nn as nn
_NO_WRAP_CLASSES = [ _NO_WRAP_CLASSES = [
# not an nn.Module # not an nn.Module
'Parameter', 'Parameter',
...@@ -47,7 +63,10 @@ if not cache_valid: ...@@ -47,7 +63,10 @@ if not cache_valid:
'# This file is auto-generated to make auto-completion work.', '# This file is auto-generated to make auto-completion work.',
'# When pytorch version does not match, it will get automatically updated.', '# When pytorch version does not match, it will get automatically updated.',
'# pylint: skip-file', '# pylint: skip-file',
f'_torch_version = "{torch.__version__}"', '# pyright: reportGeneralTypeIssues=false',
f'# _torch_version = {torch.__version__}',
f'# _torch_nn_cache_version = {cache_version}',
'import typing',
'import torch.nn as nn', 'import torch.nn as nn',
'from nni.retiarii.serializer import basic_unit', 'from nni.retiarii.serializer import basic_unit',
] ]
...@@ -66,10 +85,9 @@ if not cache_valid: ...@@ -66,10 +85,9 @@ if not cache_valid:
'It means your PyTorch version might not be supported.', RuntimeWarning) 'It means your PyTorch version might not be supported.', RuntimeWarning)
code.append(f'{name} = nn.{name}') code.append(f'{name} = nn.{name}')
elif name in _WRAP_WITHOUT_TAG_CLASSES: elif name in _WRAP_WITHOUT_TAG_CLASSES:
code.append(f'{name} = basic_unit(nn.{name}, basic_unit_tag=False)') code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}, basic_unit_tag=False))')
else: else:
code.append(f'{name} = basic_unit(nn.{name})') code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}))')
all_names.append(name) all_names.append(name)
elif inspect.isfunction(obj) or inspect.ismodule(obj): elif inspect.isfunction(obj) or inspect.ismodule(obj):
...@@ -78,12 +96,19 @@ if not cache_valid: ...@@ -78,12 +96,19 @@ if not cache_valid:
code.append(f'__all__ = {all_names}') code.append(f'__all__ = {all_names}')
return '\n'.join(code)
def write_cache(code: str) -> None:
with nn_cache_file_path.open('w') as fp: with nn_cache_file_path.open('w') as fp:
fp.write('\n'.join(code)) fp.write(code)
code = generate_stub_file()
if not validate_cache():
write_cache(code)
# Import all modules from generated _nn.py del Path, validate_cache, write_cache, cache_version, nn_cache_file_path, code
from . import _nn # pylint: disable=no-name-in-module from ._nn import * # pylint: disable=import-error, wildcard-import, unused-wildcard-import
__all__ = _nn.__all__
from ._nn import * # pylint: disable=import-error, wildcard-import
...@@ -20,7 +20,7 @@ from .supermodule.base import BaseSuperNetModule ...@@ -20,7 +20,7 @@ from .supermodule.base import BaseSuperNetModule
__all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules'] __all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']
MutationHook = Callable[[nn.Module, str, Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]] MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
def traverse_and_mutate_submodules( def traverse_and_mutate_submodules(
...@@ -149,11 +149,12 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -149,11 +149,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
The hook list will be appended by ``default_mutation_hooks`` in each one-shot module. The hook list will be appended by ``default_mutation_hooks`` in each one-shot module.
To be more specific, the input arguments are three arguments: To be more specific, the input arguments are four arguments:
#. a module that might be processed, #. a module that might be processed,
#. name of the module in its parent module, #. name of the module in its parent module,
#. a memo dict whose usage depends on the particular algorithm. #. a memo dict whose usage depends on the particular algorithm.
#. keyword arguments (configurations).
Note that the memo should be read/written by hooks. Note that the memo should be read/written by hooks.
There won't be any hooks called on root module. There won't be any hooks called on root module.
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
# type: ignore
import copy import copy
import logging import logging
from collections import OrderedDict from collections import OrderedDict
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
# type: ignore
import logging import logging
import torch import torch
......
...@@ -27,7 +27,7 @@ which interprets the slice and apply it on a tensor. ...@@ -27,7 +27,7 @@ which interprets the slice and apply it on a tensor.
""" """
import operator import operator
from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic, cast
import numpy as np import numpy as np
import torch import torch
...@@ -128,9 +128,10 @@ class Slicable(Generic[T]): ...@@ -128,9 +128,10 @@ class Slicable(Generic[T]):
raise TypeError(f'Unsuppoted weight type: {type(weight)}') raise TypeError(f'Unsuppoted weight type: {type(weight)}')
self.weight = weight self.weight = weight
def __getitem__(self, index: multidim_slice) -> T: def __getitem__(self, index: Union[slice_type, multidim_slice]) -> T:
if not isinstance(index, tuple): if not isinstance(index, tuple):
index = (index, ) index = (index, )
index = cast(multidim_slice, index)
# Get the dict value in index's leafs # Get the dict value in index's leafs
# There can be at most one dict # There can be at most one dict
......
...@@ -24,7 +24,7 @@ class BaseSuperNetModule(nn.Module): ...@@ -24,7 +24,7 @@ class BaseSuperNetModule(nn.Module):
rather than their compositions. rather than their compositions.
""" """
def resample(self, memo: Dict[str, Any] = None) -> Dict[str, Any]: def resample(self, memo: Dict[str, Any]) -> Dict[str, Any]:
""" """
Resample the super-net module. Resample the super-net module.
...@@ -40,7 +40,7 @@ class BaseSuperNetModule(nn.Module): ...@@ -40,7 +40,7 @@ class BaseSuperNetModule(nn.Module):
""" """
raise NotImplementedError() raise NotImplementedError()
def export(self, memo: Dict[str, Any] = None) -> Dict[str, Any]: def export(self, memo: Dict[str, Any]) -> Dict[str, Any]:
""" """
Export the final architecture within this module. Export the final architecture within this module.
It should have the same keys as ``search_space_spec()``. It should have the same keys as ``search_space_spec()``.
......
...@@ -275,11 +275,11 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy): ...@@ -275,11 +275,11 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
if not arch: if not arch:
yield name, p yield name, p
def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]: def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
"""Differentiable. Do nothing in resample.""" """Differentiable. Do nothing in resample."""
return {} return {}
def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]: def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
"""Export is also random for each leaf value choice.""" """Export is also random for each leaf value choice."""
result = {} result = {}
for name, spec in operation.search_space_spec().items(): for name, spec in operation.search_space_spec().items():
......
...@@ -8,11 +8,12 @@ which is commonly known as super-kernel (as in channel search), or weight entang ...@@ -8,11 +8,12 @@ which is commonly known as super-kernel (as in channel search), or weight entang
import inspect import inspect
import itertools import itertools
from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor
import nni.retiarii.nn.pytorch as retiarii_nn import nni.retiarii.nn.pytorch as retiarii_nn
from nni.common.hpo_utils import ParameterSpec from nni.common.hpo_utils import ParameterSpec
...@@ -46,11 +47,11 @@ class MixedOperationSamplingPolicy: ...@@ -46,11 +47,11 @@ class MixedOperationSamplingPolicy:
""" """
pass pass
def resample(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]: def resample(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
"""The handler of :meth:`MixedOperation.resample`.""" """The handler of :meth:`MixedOperation.resample`."""
raise NotImplementedError() raise NotImplementedError()
def export(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]: def export(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
"""The handler of :meth:`MixedOperation.export`.""" """The handler of :meth:`MixedOperation.export`."""
raise NotImplementedError() raise NotImplementedError()
...@@ -513,43 +514,42 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention): ...@@ -513,43 +514,42 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
embed_dim = _W(embed_dim) embed_dim = _W(embed_dim)
# in projection weights & biases has q, k, v weights concatenated together # in projection weights & biases has q, k, v weights concatenated together
in_proj_bias = in_proj_weight = None in_proj_bias: Optional[Tensor] = None
in_proj_weight: Optional[Tensor] = None
if self.in_proj_bias is not None: if self.in_proj_bias is not None:
in_proj_bias = _S(self.in_proj_bias)[self._to_proj_slice(embed_dim)] in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim)]
if self.in_proj_weight is not None: if self.in_proj_weight is not None:
in_proj_weight = _S(self.in_proj_weight)[self._to_proj_slice(embed_dim), :embed_dim] in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim), :embed_dim]
bias_k = _S(self.bias_k)[:, :, :embed_dim] if self.bias_k is not None else None bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim] if self.bias_k is not None else None
bias_v = _S(self.bias_v)[:, :, :embed_dim] if self.bias_v is not None else None bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim] if self.bias_v is not None else None
out_proj_weight = _S(self.out_proj.weight)[:embed_dim, :embed_dim] out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim, :embed_dim]
out_proj_bias = _S(self.out_proj.bias)[:embed_dim] if self.out_proj.bias is not None else None out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim] if self.out_proj.bias is not None else None
if not qkv_same_embed_dim: if not qkv_same_embed_dim:
kdim = _W(kdim) q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim, :embed_dim]
vdim = _W(vdim) k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim]
k_proj = _S(k_proj)[:, :_W(kdim)]
q_proj = _S(self.q_proj_weight)[:embed_dim, :embed_dim] v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim]
k_proj = _S(self.k_proj_weight)[:embed_dim] v_proj = _S(v_proj)[:, :_W(vdim)]
k_proj = _S(k_proj)[:, :kdim]
v_proj = _S(self.v_proj_weight)[:embed_dim]
v_proj = _S(v_proj)[:, :vdim]
# The rest part is basically same as pytorch # The rest part is basically same as pytorch
attn_output, attn_output_weights = F.multi_head_attention_forward( attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, used_embed_dim, num_heads, query, key, value, used_embed_dim, num_heads,
in_proj_weight, in_proj_bias, cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
bias_k, bias_v, self.add_zero_attn, bias_k, bias_v, self.add_zero_attn,
dropout, out_proj_weight, out_proj_bias, dropout, out_proj_weight, cast(Tensor, out_proj_bias),
training=self.training, training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights, key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, use_separate_proj_weight=True, attn_mask=attn_mask, use_separate_proj_weight=True,
q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj) q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj)
else: else:
# Cast tensor here because of a bug in pytorch stub
attn_output, attn_output_weights = F.multi_head_attention_forward( attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, used_embed_dim, num_heads, query, key, value, used_embed_dim, num_heads,
in_proj_weight, in_proj_bias, cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
bias_k, bias_v, self.add_zero_attn, bias_k, bias_v, self.add_zero_attn,
dropout, out_proj_weight, out_proj_bias, dropout, out_proj_weight, cast(Tensor, out_proj_bias),
training=self.training, training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights, key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask) attn_mask=attn_mask)
......
...@@ -9,7 +9,7 @@ The support remains limited. Known limitations include: ...@@ -9,7 +9,7 @@ The support remains limited. Known limitations include:
- The code contains duplicates. Needs refactor. - The code contains duplicates. Needs refactor.
""" """
from typing import List, Tuple, Optional from typing import List, Tuple, Optional, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -94,7 +94,7 @@ class ProxylessMixedLayer(DifferentiableMixedLayer): ...@@ -94,7 +94,7 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
self._sample_idx = self.op_names.index(self._sampled) self._sample_idx = self.op_names.index(self._sampled)
else: else:
probs = self._softmax(self._arch_alpha) probs = self._softmax(self._arch_alpha)
self._sample_idx = torch.multinomial(probs, 1)[0].item() self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
self._sampled = self.op_names[self._sample_idx] self._sampled = self.op_names[self._sample_idx]
# set binary gates # set binary gates
...@@ -109,10 +109,11 @@ class ProxylessMixedLayer(DifferentiableMixedLayer): ...@@ -109,10 +109,11 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
"""Chose the argmax if label isn't found in memo.""" """Chose the argmax if label isn't found in memo."""
if self.label in memo: if self.label in memo:
return {} # nothing new to export return {} # nothing new to export
return {self.label: self.op_names[torch.argmax(self._arch_alpha).item()]} return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
def finalize_grad(self): def finalize_grad(self):
binary_grads = self._binary_gates.grad binary_grads = self._binary_gates.grad
assert binary_grads is not None
with torch.no_grad(): with torch.no_grad():
if self._arch_alpha.grad is None: if self._arch_alpha.grad is None:
self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data) self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
...@@ -164,13 +165,13 @@ class ProxylessMixedInput(DifferentiableMixedInput): ...@@ -164,13 +165,13 @@ class ProxylessMixedInput(DifferentiableMixedInput):
else: else:
probs = self._softmax(self._arch_alpha) probs = self._softmax(self._arch_alpha)
sample = torch.multinomial(probs, 1)[0].item() sample = torch.multinomial(probs, 1)[0].item()
self._sampled = sample self._sampled = int(sample)
# set binary gates # set binary gates
with torch.no_grad(): with torch.no_grad():
self._binary_gates.zero_() self._binary_gates.zero_()
self._binary_gates.grad = torch.zeros_like(self._binary_gates.data) self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
self._binary_gates.data[sample] = 1.0 self._binary_gates.data[cast(int, self._sampled)] = 1.0
return {self.label: self._sampled} return {self.label: self._sampled}
...@@ -182,6 +183,7 @@ class ProxylessMixedInput(DifferentiableMixedInput): ...@@ -182,6 +183,7 @@ class ProxylessMixedInput(DifferentiableMixedInput):
def finalize_grad(self): def finalize_grad(self):
binary_grads = self._binary_gates.grad binary_grads = self._binary_gates.grad
assert binary_grads is not None
with torch.no_grad(): with torch.no_grad():
if self._arch_alpha.grad is None: if self._arch_alpha.grad is None:
self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data) self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
......
...@@ -129,6 +129,8 @@ class PathSamplingInput(BaseSuperNetModule): ...@@ -129,6 +129,8 @@ class PathSamplingInput(BaseSuperNetModule):
if isinstance(module, InputChoice): if isinstance(module, InputChoice):
if module.reduction not in ['sum', 'mean', 'concat']: if module.reduction not in ['sum', 'mean', 'concat']:
raise ValueError('Only input choice of sum/mean/concat reduction is supported.') raise ValueError('Only input choice of sum/mean/concat reduction is supported.')
if module.n_chosen is None:
raise ValueError('n_chosen is None is not supported yet.')
return cls(module.n_candidates, module.n_chosen, module.reduction, module.label) return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)
def forward(self, input_tensors): def forward(self, input_tensors):
...@@ -161,7 +163,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy): ...@@ -161,7 +163,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
# Sampling arguments. This should have the same keys with `operation.mutable_arguments` # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
self._sampled: Optional[Dict[str, Any]] = None self._sampled: Optional[Dict[str, Any]] = None
def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]: def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
"""Random sample for each leaf value choice.""" """Random sample for each leaf value choice."""
result = {} result = {}
space_spec = operation.search_space_spec() space_spec = operation.search_space_spec()
...@@ -179,7 +181,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy): ...@@ -179,7 +181,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
return result return result
def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]: def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
"""Export is also random for each leaf value choice.""" """Export is also random for each leaf value choice."""
result = {} result = {}
space_spec = operation.search_space_spec() space_spec = operation.search_space_spec()
......
...@@ -132,7 +132,7 @@ def _replace_module_with_type(root_module, init_fn, type_name, modules): ...@@ -132,7 +132,7 @@ def _replace_module_with_type(root_module, init_fn, type_name, modules):
for name, child in m.named_children(): for name, child in m.named_children():
if isinstance(child, type_name): if isinstance(child, type_name):
setattr(m, name, init_fn(child)) setattr(m, name, init_fn(child))
modules.append((child.key, getattr(m, name))) modules.append((child.label, getattr(m, name)))
else: else:
apply(child) apply(child)
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from typing import (Any, Dict, List) from typing import (Any, Dict, List, Optional, cast)
from . import debug_configs from . import debug_configs
...@@ -34,6 +34,8 @@ class Operation: ...@@ -34,6 +34,8 @@ class Operation:
Arbitrary key-value parameters (e.g. kernel_size). Arbitrary key-value parameters (e.g. kernel_size).
""" """
io_names: List[str] = []
def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False, attributes: Dict[str, Any] = {}): def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False, attributes: Dict[str, Any] = {}):
assert _internal, '`Operation()` is private, use `Operation.new()` instead' assert _internal, '`Operation()` is private, use `Operation.new()` instead'
self.type: str = type_name self.type: str = type_name
...@@ -43,7 +45,7 @@ class Operation: ...@@ -43,7 +45,7 @@ class Operation:
def to_init_code(self, field: str) -> str: def to_init_code(self, field: str) -> str:
raise NotImplementedError() raise NotImplementedError()
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
raise NotImplementedError() raise NotImplementedError()
def _to_class_name(self) -> str: def _to_class_name(self) -> str:
...@@ -53,8 +55,8 @@ class Operation: ...@@ -53,8 +55,8 @@ class Operation:
return True return True
@staticmethod @staticmethod
def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None, def new(type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None), cell_name: str = cast(str, None),
attributes: Dict[str, Any] = None) -> 'Operation': attributes: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Operation':
parameters = parameters or {} parameters = parameters or {}
attributes = attributes or {} attributes = attributes or {}
if type_name == '_cell': if type_name == '_cell':
...@@ -98,16 +100,16 @@ class PyTorchOperation(Operation): ...@@ -98,16 +100,16 @@ class PyTorchOperation(Operation):
subclass_name = 'FunctionalOperator' subclass_name = 'FunctionalOperator'
for subclass in cls.__subclasses__(): for subclass in cls.__subclasses__():
if hasattr(subclass, '_ori_type_name') and \ if hasattr(subclass, '_ori_type_name') and \
subclass_name in subclass._ori_type_name: subclass_name in cast(Any, subclass)._ori_type_name:
return subclass return subclass
for subclass in cls.__subclasses__(): for subclass in cls.__subclasses__():
if hasattr(subclass, '_artificial_op_name') and \ if hasattr(subclass, '_artificial_op_name') and \
subclass_name in subclass._artificial_op_name: subclass_name in cast(Any, subclass)._artificial_op_name:
return subclass return subclass
return cls return cls
@classmethod @classmethod
def to_class_name(cls, type_name) -> str: def to_class_name(cls, type_name) -> Optional[str]:
if type_name.startswith('__torch__.'): if type_name.startswith('__torch__.'):
return type_name[len('__torch__.'):] return type_name[len('__torch__.'):]
elif type_name.startswith('__mutated__.'): elif type_name.startswith('__mutated__.'):
...@@ -119,7 +121,7 @@ class PyTorchOperation(Operation): ...@@ -119,7 +121,7 @@ class PyTorchOperation(Operation):
def is_functional(cls, type_name) -> bool: def is_functional(cls, type_name) -> bool:
return type_name.startswith('Function.') return type_name.startswith('Function.')
def _to_class_name(self) -> str: def _to_class_name(self) -> Optional[str]:
if self.type.startswith('__torch__.'): if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):] return self.type[len('__torch__.'):]
elif self.type.startswith('__mutated__.'): elif self.type.startswith('__mutated__.'):
...@@ -127,7 +129,7 @@ class PyTorchOperation(Operation): ...@@ -127,7 +129,7 @@ class PyTorchOperation(Operation):
else: else:
return None return None
def get_import_pkg(self) -> str: def get_import_pkg(self) -> Optional[str]:
if self.type.startswith('__torch__.'): if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):].split('.')[0] return self.type[len('__torch__.'):].split('.')[0]
elif self.type.startswith('__mutated__.'): elif self.type.startswith('__mutated__.'):
...@@ -135,14 +137,14 @@ class PyTorchOperation(Operation): ...@@ -135,14 +137,14 @@ class PyTorchOperation(Operation):
else: else:
return None return None
def to_init_code(self, field: str) -> str: def to_init_code(self, field: str) -> Optional[str]:
if self._to_class_name() is not None: if self._to_class_name() is not None:
assert 'positional_args' not in self.parameters assert 'positional_args' not in self.parameters
kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items()) kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
return f'self.{field} = {self._to_class_name()}({kw_params})' return f'self.{field} = {self._to_class_name()}({kw_params})'
return None return None
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
""" """
Parameters Parameters
---------- ----------
...@@ -207,7 +209,9 @@ class Cell(PyTorchOperation): ...@@ -207,7 +209,9 @@ class Cell(PyTorchOperation):
No real usage. Exists for compatibility with base class. No real usage. Exists for compatibility with base class.
""" """
def __init__(self, cell_name: str, parameters: Dict[str, Any] = None, attributes: Dict[str, Any] = None): def __init__(self, cell_name: str,
parameters: Dict[str, Any] = cast(Dict[str, Any], None),
attributes: Dict[str, Any] = cast(Dict[str, Any], None)):
self.type = '_cell' self.type = '_cell'
self.cell_name = cell_name self.cell_name = cell_name
self.parameters = parameters or {} self.parameters = parameters or {}
...@@ -217,7 +221,7 @@ class Cell(PyTorchOperation): ...@@ -217,7 +221,7 @@ class Cell(PyTorchOperation):
# TODO: ugly, think about how to refactor this part # TODO: ugly, think about how to refactor this part
return _convert_name(self.cell_name) return _convert_name(self.cell_name)
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = self.{field}({", ".join(inputs)})' return f'{output} = self.{field}({", ".join(inputs)})'
class _IOPseudoOperation(Operation): class _IOPseudoOperation(Operation):
...@@ -227,7 +231,7 @@ class _IOPseudoOperation(Operation): ...@@ -227,7 +231,7 @@ class _IOPseudoOperation(Operation):
especially in static type checking. especially in static type checking.
""" """
def __init__(self, type_name: str, io_names: List = None): def __init__(self, type_name: str, io_names: List[str] = cast(List[str], None)):
assert type_name.startswith('_') assert type_name.startswith('_')
super(_IOPseudoOperation, self).__init__(type_name, {}, True) super(_IOPseudoOperation, self).__init__(type_name, {}, True)
self.io_names = io_names self.io_names = io_names
...@@ -235,7 +239,7 @@ class _IOPseudoOperation(Operation): ...@@ -235,7 +239,7 @@ class _IOPseudoOperation(Operation):
def to_init_code(self, field: str) -> str: def to_init_code(self, field: str) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"') raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"') raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
def __bool__(self) -> bool: def __bool__(self) -> bool:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
""" """
Definition of operation types. Definition of operation types.
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
from typing import (Any, Dict, List) from typing import (Any, Dict, List)
import torch import torch
import torch.nn.functional as nn_functional
from ..operation import PyTorchOperation from ..operation import PyTorchOperation
...@@ -39,23 +42,23 @@ class NoOpIdentity(PyTorchOperation): ...@@ -39,23 +42,23 @@ class NoOpIdentity(PyTorchOperation):
""" """
_ori_type_name = ['noop_identity'] _ori_type_name = ['noop_identity']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {", ".join(inputs)}' return f'{output} = {", ".join(inputs)}'
class ModuleOperator(PyTorchOperation): class ModuleOperator(PyTorchOperation):
_ori_type_name = ['ModuleOperator', 'shared'] _ori_type_name = ['ModuleOperator', 'shared']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = self.{field}({", ".join(inputs)})' return f'{output} = self.{field}({", ".join(inputs)})'
class FunctionalOperator(PyTorchOperation): class FunctionalOperator(PyTorchOperation):
_ori_type_name = ['FunctionalOperator'] _ori_type_name = ['FunctionalOperator']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
func_name = self.type[len('Function.'):] func_name = self.type[len('Function.'):]
if not hasattr(torch.nn.functional, func_name): if not hasattr(nn_functional, func_name):
raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, ' raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
f'{func_name} is not in it.') f'{func_name} is not in it.')
return f'{output} = F.{func_name}({", ".join(inputs)})' return f'{output} = F.{func_name}({", ".join(inputs)})'
...@@ -64,7 +67,7 @@ class FunctionalOperator(PyTorchOperation): ...@@ -64,7 +67,7 @@ class FunctionalOperator(PyTorchOperation):
class PrimConstant(PyTorchOperation): class PrimConstant(PyTorchOperation):
_ori_type_name = ['prim::Constant'] _ori_type_name = ['prim::Constant']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant # TODO: refactor this part, maybe we can remove the code gen of prim::Constant
# TODO: deal with all the types # TODO: deal with all the types
if self.parameters['type'] in ['None', 'NoneType']: if self.parameters['type'] in ['None', 'NoneType']:
...@@ -87,28 +90,28 @@ class PrimConstant(PyTorchOperation): ...@@ -87,28 +90,28 @@ class PrimConstant(PyTorchOperation):
class PrimListConstruct(PyTorchOperation): class PrimListConstruct(PyTorchOperation):
_ori_type_name = ['prim::ListConstruct'] _ori_type_name = ['prim::ListConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = [{", ".join(inputs)}]' return f'{output} = [{", ".join(inputs)}]'
class PrimListUnpack(PyTorchOperation): class PrimListUnpack(PyTorchOperation):
_ori_type_name = ['prim::ListUnpack'] _ori_type_name = ['prim::ListUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}' return f'{output} = {inputs[0]}'
class PrimTupleConstruct(PyTorchOperation): class PrimTupleConstruct(PyTorchOperation):
_ori_type_name = ['prim::TupleConstruct'] _ori_type_name = ['prim::TupleConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = ({", ".join(inputs)})' return f'{output} = ({", ".join(inputs)})'
class PrimTupleUnpack(PyTorchOperation): class PrimTupleUnpack(PyTorchOperation):
_ori_type_name = ['prim::TupleUnpack'] _ori_type_name = ['prim::TupleUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# have single output here, because the following code uses index to access the unpacked values # have single output here, because the following code uses index to access the unpacked values
assert len(inputs) == 1 assert len(inputs) == 1
return f'{output} = {inputs[0]}' return f'{output} = {inputs[0]}'
...@@ -117,7 +120,7 @@ class PrimTupleUnpack(PyTorchOperation): ...@@ -117,7 +120,7 @@ class PrimTupleUnpack(PyTorchOperation):
class PrimGetAttr(PyTorchOperation): class PrimGetAttr(PyTorchOperation):
_ori_type_name = ['prim::GetAttr'] _ori_type_name = ['prim::GetAttr']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.parameters['value'] is not None: if self.parameters['value'] is not None:
return f"{output} = {self.parameters['value']}" return f"{output} = {self.parameters['value']}"
else: else:
...@@ -127,14 +130,14 @@ class PrimGetAttr(PyTorchOperation): ...@@ -127,14 +130,14 @@ class PrimGetAttr(PyTorchOperation):
class PrimUncheckedCast(PyTorchOperation):
    """``prim::unchecked_cast`` has no runtime effect: the input is passed
    through unchanged."""
    _ori_type_name = ['prim::unchecked_cast']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return '{} = {}'.format(output, inputs[0])
class SimpleMember(PyTorchOperation):
    """Emit plain attribute access (e.g. ``x.is_cuda``, ``x.data``); the
    attribute name is taken from the op type itself."""
    _ori_type_name = ['prim::is_cuda', 'prim::data']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        # The attribute name is the part after '::' in the op type, e.g. 'is_cuda'.
        attr_name = self.type.split('::')[-1]
        return f'{output} = {inputs[0]}.{attr_name}'
...@@ -142,16 +145,16 @@ class SimpleMember(PyTorchOperation): ...@@ -142,16 +145,16 @@ class SimpleMember(PyTorchOperation):
class AtenContiguous(PyTorchOperation):
    """Emit ``Tensor.contiguous`` with an explicit ``memory_format`` keyword,
    decoded from the integer code in ``inputs_value``."""
    _ori_type_name = ['aten::contiguous']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        # Memory-format integer codes are defined in pytorch/c10/core/MemoryFormat.h.
        assert inputs_value is not None
        fmt_code = inputs_value[1]
        assert fmt_code in (0, 1, 2)
        return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[fmt_code]})'
class AtenGetitem(PyTorchOperation):
    """Emit subscript access: ``out = container[index]``."""
    _ori_type_name = ['aten::__getitem__']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        assert len(inputs) == 2
        container, index = inputs
        return f'{output} = {container}[{index}]'
...@@ -159,7 +162,7 @@ class AtenGetitem(PyTorchOperation): ...@@ -159,7 +162,7 @@ class AtenGetitem(PyTorchOperation):
class AtenAppend(PyTorchOperation):
    """Emit ``list.append``; since ``append`` returns ``None``, the node's
    output is bound to the (mutated) list itself."""
    _ori_type_name = ['aten::append']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        assert len(inputs) == 2
        target, item = inputs
        # Discard append()'s None result and re-expose the list as the output.
        return f'_, {output} = {target}.append({item}), {target}'
...@@ -167,7 +170,7 @@ class AtenAppend(PyTorchOperation): ...@@ -167,7 +170,7 @@ class AtenAppend(PyTorchOperation):
class MergedSlice(PyTorchOperation): class MergedSlice(PyTorchOperation):
_ori_type_name = ['MergedSlice'] _ori_type_name = ['MergedSlice']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if (len(inputs) - 1) % 4 == 0: if (len(inputs) - 1) % 4 == 0:
slices = [] slices = []
dim = int((len(inputs) - 1) / 4) dim = int((len(inputs) - 1) / 4)
...@@ -187,21 +190,21 @@ class MergedSlice(PyTorchOperation): ...@@ -187,21 +190,21 @@ class MergedSlice(PyTorchOperation):
class AtenBool(PyTorchOperation):
    """Emit a ``bool(...)`` conversion."""
    _ori_type_name = ['aten::Bool']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return '{} = bool({})'.format(output, inputs[0])
class AtenNot(PyTorchOperation):
    """Emit a logical negation: ``out = not value``."""
    _ori_type_name = ['aten::__not__']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return '{} = not {}'.format(output, inputs[0])
class AtenCat(PyTorchOperation):
    """Emit ``torch.cat`` over a sequence of tensors along a given dimension."""
    _ori_type_name = ['aten::cat']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        assert len(inputs) == 2
        tensors, dim = inputs
        return f'{output} = torch.cat({tensors}, dim={dim})'
...@@ -215,7 +218,7 @@ class AtenTensors(PyTorchOperation): ...@@ -215,7 +218,7 @@ class AtenTensors(PyTorchOperation):
'aten::new_empty', 'aten::new_zeros', 'aten::arange', 'aten::new_empty', 'aten::new_zeros', 'aten::arange',
'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor'] 'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
schemas = torch._C._jit_get_schemas_for_operator(self.type) schemas = torch._C._jit_get_schemas_for_operator(self.type)
# match number of inputs # match number of inputs
overloaded_defs = [len(s.arguments) for s in schemas] overloaded_defs = [len(s.arguments) for s in schemas]
...@@ -257,40 +260,41 @@ class AtenTensors(PyTorchOperation): ...@@ -257,40 +260,41 @@ class AtenTensors(PyTorchOperation):
class AtenFloordiv(PyTorchOperation):
    """Emit integer (floor) division: ``out = a // b``."""
    _ori_type_name = ['aten::floordiv']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        lhs, rhs = inputs[0], inputs[1]
        return f'{output} = {lhs} // {rhs}'
class AtenMul(PyTorchOperation):
    """Emit multiplication: ``out = a * b``."""
    _ori_type_name = ['aten::mul']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        lhs, rhs = inputs[0], inputs[1]
        return f'{output} = {lhs} * {rhs}'
class AtenLen(PyTorchOperation):
    """Emit a ``len(...)`` call."""
    _ori_type_name = ['aten::len']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return '{} = len({})'.format(output, inputs[0])
class AtenIntImplicit(PyTorchOperation):
    """Emit scalar conversions: the ``*Implicit`` variants pass the value
    through unchanged, while ``aten::Int`` / ``aten::Float`` wrap it in the
    corresponding Python conversion call."""
    _ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        src = inputs[0]
        # Order matters: 'aten::IntImplicit' ends with 'Implicit' and must be
        # matched before the exact-name checks below.
        if self.type.endswith('Implicit'):
            expr = src
        elif self.type == 'aten::Int':
            expr = f'int({src})'
        elif self.type == 'aten::Float':
            expr = f'float({src})'
        else:
            raise TypeError(f'Unexpected type: {self.type}')
        return f'{output} = {expr}'
class AtenIndex(PyTorchOperation):
    """Emit (advanced) tensor indexing: ``out = tensor[indices]``."""
    _ori_type_name = ['aten::index']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        tensor, indices = inputs[0], inputs[1]
        return f'{output} = {tensor}[{indices}]'
...@@ -355,13 +359,13 @@ def _get_tensor_ops(): ...@@ -355,13 +359,13 @@ def _get_tensor_ops():
def _get_torch_ops(): def _get_torch_ops():
torch_op_args = {} torch_op_args = {}
for mod in torch.jit._builtins._modules_containing_builtins: for mod in torch.jit._builtins._modules_containing_builtins: # type: ignore
name = mod.__name__ name = mod.__name__
if name == 'torch._C._nn': if name == 'torch._C._nn':
continue continue
# only process 'torch.XXX' # only process 'torch.XXX'
for elem in dir(mod): for elem in dir(mod):
builtin = torch.jit._builtins._find_builtin(getattr(mod, elem)) builtin = torch.jit._builtins._find_builtin(getattr(mod, elem)) # type: ignore
if builtin is not None: if builtin is not None:
schemas = torch._C._jit_get_schemas_for_operator(builtin) schemas = torch._C._jit_get_schemas_for_operator(builtin)
for schema in schemas: for schema in schemas:
...@@ -436,7 +440,7 @@ class TensorOps(PyTorchOperation): ...@@ -436,7 +440,7 @@ class TensorOps(PyTorchOperation):
return None return None
raise RuntimeError(f'tensor op type {_type} has no matched') raise RuntimeError(f'tensor op type {_type} has no matched')
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# TODO: deal with conditional ops # TODO: deal with conditional ops
if self.type in TensorOps.comparison_ops: if self.type in TensorOps.comparison_ops:
return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})' return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
...@@ -486,7 +490,7 @@ class TorchOps(PyTorchOperation): ...@@ -486,7 +490,7 @@ class TorchOps(PyTorchOperation):
else: else:
raise RuntimeError(f'torch op type {_type} has no matched') raise RuntimeError(f'torch op type {_type} has no matched')
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
matched_args = TorchOps._get_matched_args(self.type, inputs) matched_args = TorchOps._get_matched_args(self.type, inputs)
op_name = self.type.split('::')[-1] op_name = self.type.split('::')[-1]
args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}' args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
...@@ -498,7 +502,7 @@ class AtenAvgpool2d(PyTorchOperation): ...@@ -498,7 +502,7 @@ class AtenAvgpool2d(PyTorchOperation):
# NOTE: it is not included in the above aten ops for unkown reason # NOTE: it is not included in the above aten ops for unkown reason
_ori_type_name = ['aten::avg_pool2d'] _ori_type_name = ['aten::avg_pool2d']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = F.avg_pool2d({", ".join(inputs)})' return f'{output} = F.avg_pool2d({", ".join(inputs)})'
...@@ -506,7 +510,7 @@ class ToDevice(PyTorchOperation): ...@@ -506,7 +510,7 @@ class ToDevice(PyTorchOperation):
_artificial_op_name = "ToDevice" _artificial_op_name = "ToDevice"
def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False, def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False,
attributes: Dict[str, Any] = None): attributes: Dict[str, Any] = {}):
self.type = "ToDevice" self.type = "ToDevice"
self.device = parameters['device'] self.device = parameters['device']
self.overridden_device_repr = None self.overridden_device_repr = None
...@@ -540,5 +544,5 @@ class AtenDet(PyTorchOperation): ...@@ -540,5 +544,5 @@ class AtenDet(PyTorchOperation):
# NOTE: it is not included in the above aten ops, maybe because torch.det is alias for torch.linalg.det # NOTE: it is not included in the above aten ops, maybe because torch.det is alias for torch.linalg.det
_ori_type_name = ['aten::linalg_det'] _ori_type_name = ['aten::linalg_det']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = torch.det({inputs[0]})' return f'{output} = torch.det({inputs[0]})'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment