Unverified commit 18962129, authored by Yuge Zhang, committed by GitHub

Add license header and typehints for NAS (#4774)

parent 8c2f717d
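Most of the typehint changes below follow one recurring idiom: a value whose declared type is `Optional` (or otherwise too loose) is narrowed with `typing.cast`, which does nothing at runtime and only informs the type checker. A minimal, self-contained sketch of the idiom, using stand-in classes rather than NNI's real `Node`/`Model` API:

```python
from typing import Optional, cast

class Node:
    def update_operation(self, op: str) -> None:
        print(f'updated to {op}')

_registry = {'conv1': Node()}

def get_node_by_name(name: str) -> Optional[Node]:
    # May return None, so the declared return type is Optional[Node].
    return _registry.get(name)

# cast() is a runtime no-op; it only tells the checker that we know the
# lookup succeeds here (an assumption the caller must uphold).
target = cast(Node, get_node_by_name('conv1'))
target.update_operation('prim::Constant')
```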
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import copy
 import warnings
-from typing import Callable, Dict, List, Union, Optional, Tuple
+from typing import Callable, Dict, List, Union, Optional, Tuple, Sequence, cast
 try:
     from typing import Literal
 except ImportError:
@@ -193,8 +196,10 @@ class Cell(nn.Module):
     def __init__(self,
                  op_candidates: Union[
                      Callable[[], List[nn.Module]],
-                     List[Union[nn.Module, _cell_op_factory_type]],
-                     Dict[str, Union[nn.Module, _cell_op_factory_type]]
+                     List[nn.Module],
+                     List[_cell_op_factory_type],
+                     Dict[str, nn.Module],
+                     Dict[str, _cell_op_factory_type]
                  ],
                  num_nodes: int,
                  num_ops_per_node: int = 1,
@@ -251,8 +256,8 @@ class Cell(nn.Module):
                 ops = self._convert_op_candidates(op_candidates, i, k, chosen)
                 # though it's layer choice and input choice here, in fixed mode, the chosen module will be created.
-                self.ops[-1].append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
-                self.inputs[-1].append(inp)
+                cast(ModuleList, self.ops[-1]).append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
+                cast(ModuleList, self.inputs[-1]).append(inp)

     @property
     def label(self):
@@ -274,13 +279,17 @@ class Cell(nn.Module):
         By default, it's the output of ``merge_op``, which is a concatenation (on ``concat_dim``)
         of some of (possibly all) the nodes' outputs in the cell.
         """
+        processed_inputs: List[torch.Tensor]
         if len(inputs) == 1 and isinstance(inputs[0], list):
-            inputs = inputs[0]
+            processed_inputs = list(inputs[0])  # shallow copy
         else:
-            inputs = list(inputs)
-        assert len(inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
-        states = self.preprocessor(inputs)
-        for ops, inps in zip(self.ops, self.inputs):
+            processed_inputs = cast(List[torch.Tensor], list(inputs))
+        assert len(processed_inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
+        states: List[torch.Tensor] = self.preprocessor(processed_inputs)
+        for ops, inps in zip(
+            cast(Sequence[Sequence[LayerChoice]], self.ops),
+            cast(Sequence[Sequence[InputChoice]], self.inputs)
+        ):
             current_state = []
             for op, inp in zip(ops, inps):
                 current_state.append(op(inp(states)))
@@ -291,7 +300,7 @@ class Cell(nn.Module):
             this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
         else:
             this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim)
-        return self.postprocessor(this_cell, inputs)
+        return self.postprocessor(this_cell, processed_inputs)

     @staticmethod
     def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]:
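The widened `op_candidates` overloads accept plain modules as well as factory callables. A hedged sketch of the two unambiguous forms, assuming the `Cell` constructor shown in the hunk above (all other arguments left at defaults):

```python
import torch.nn as nn
from nni.retiarii.nn.pytorch import Cell

# List[nn.Module]: anonymous candidates, addressed by index.
cell = Cell([nn.Conv2d(16, 16, 3, padding=1),
             nn.MaxPool2d(3, stride=1, padding=1)], num_nodes=4)

# Dict[str, nn.Module]: named candidates, which keeps exported choices readable.
cell = Cell({'conv3x3': nn.Conv2d(16, 16, 3, padding=1),
             'maxpool': nn.MaxPool2d(3, stride=1, padding=1)}, num_nodes=4)
```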
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import copy
 import warnings
 from collections import OrderedDict
-from typing import Callable, List, Union, Tuple, Optional
+from typing import Callable, List, Dict, Union, Tuple, Optional

 import torch
 import torch.nn as nn

 from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
-from .api import LayerChoice, ValueChoice, ValueChoiceX
+from .api import LayerChoice, ValueChoice, ValueChoiceX, ChoiceOf
 from .cell import Cell
 from .nasbench101 import NasBench101Cell, NasBench101Mutator
 from .mutation_utils import Mutable, generate_new_label, get_fixed_value
@@ -64,7 +67,7 @@ class Repeat(Mutable):
                      List[Callable[[int], nn.Module]],
                      nn.Module,
                      List[nn.Module]],
-                 depth: Union[int, Tuple[int, int], ValueChoice], *, label: Optional[str] = None):
+                 depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
         if isinstance(depth, tuple):
             # we can't create a value choice here,
             # otherwise we will have two value choices, one created here, another in init.
@@ -90,7 +93,7 @@ class Repeat(Mutable):
                      List[Callable[[int], nn.Module]],
                      nn.Module,
                      List[nn.Module]],
-                 depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
+                 depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
         super().__init__()
         self._label = None  # by default, no label
@@ -192,7 +195,7 @@ class NasBench201Cell(nn.Module):
             return OrderedDict([(str(i), t) for i, t in enumerate(x)])
         return OrderedDict(x)

-    def __init__(self, op_candidates: List[Callable[[int, int], nn.Module]],
+    def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
                  in_features: int, out_features: int, num_tensors: int = 4,
                  label: Optional[str] = None):
         super().__init__()
@@ -214,16 +217,15 @@ class NasBench201Cell(nn.Module):
                 node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))  # put __ here to be compatible with base engine
             self.layers.append(node_ops)

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         """
         The forward of input choice is simply selecting first on all choices.
         It shouldn't be called directly by users in most cases.
         """
-        tensors = [inputs]
+        tensors: List[torch.Tensor] = [inputs]
         for layer in self.layers:
-            current_tensor = []
-            for i, op in enumerate(layer):
-                current_tensor.append(op(tensors[i]))
-            current_tensor = torch.sum(torch.stack(current_tensor), 0)
-            tensors.append(current_tensor)
+            current_tensor: List[torch.Tensor] = []
+            for i, op in enumerate(layer):  # type: ignore
+                current_tensor.append(op(tensors[i]))  # type: ignore
+            tensors.append(torch.sum(torch.stack(current_tensor), 0))
         return tensors[-1]
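`Repeat` now accepts a `ChoiceOf[int]` for `depth`, i.e. an integer-valued `ValueChoice` expression, in addition to a fixed int or an inclusive range. A hedged sketch of both spellings, assuming the public `nni.retiarii.nn.pytorch` API; the block bodies are illustrative:

```python
import torch.nn as nn
import nni.retiarii.nn.pytorch as retiarii_nn

# depth as a plain range: Repeat samples an int between 1 and 3 (inclusive).
block = retiarii_nn.Repeat(lambda index: nn.Linear(16, 16), depth=(1, 3))

# depth as an explicit ValueChoice, whose static type is ChoiceOf[int];
# useful when the same depth decision should be shared via its label.
shared_depth = retiarii_nn.ValueChoice([1, 2, 3], label='depth')
block = retiarii_nn.Repeat(lambda index: nn.Linear(16, 16), depth=shared_depth)
```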
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
+from packaging.version import Version

 import torch
 import torch.nn as nn
@@ -233,7 +235,7 @@ class AutoActivation(nn.Module):
     -----
     Currently, ``beta`` is not a per-channel parameter.
     """
-    def __init__(self, unit_num: int = 1, label: str = None):
+    def __init__(self, unit_num: int = 1, label: str | None = None):
         super().__init__()
         self._label = generate_new_label(label)
         self.unaries = nn.ModuleList()
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from typing import Any, Optional, Tuple, Union

 import torch.nn as nn
@@ -41,7 +44,7 @@ def generate_new_label(label: Optional[str]):
     return label

-def get_fixed_value(label: str) -> Any:
+def get_fixed_value(label: Optional[str]) -> Any:
     ret = get_current_context('fixed')
     try:
         return ret[generate_new_label(label)]
@@ -49,7 +52,7 @@ def get_fixed_value(label: str) -> Any:
         raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')

-def get_fixed_dict(label_prefix: str) -> Tuple[str, Any]:
+def get_fixed_dict(label_prefix: Optional[str]) -> Tuple[str, Any]:
     ret = get_current_context('fixed')
     try:
         label_prefix = generate_new_label(label_prefix)
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

 import inspect
-from typing import Any, List, Optional, Tuple, Dict, Iterator
+from typing import Any, List, Optional, Tuple, Dict, Iterator, Iterable, cast

 import torch.nn as nn
@@ -28,12 +28,14 @@ class LayerChoiceMutator(Mutator):
             # Each layer choice corresponds to a cell, which is unconnected in the base graph.
             # We add the connections here in the mutation logic.
             # Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
-            target = model.graphs[node.operation.cell_name]
+            target = model.graphs[cast(Cell, node.operation).cell_name]
             chosen_node = target.get_node_by_name(chosen)
             assert chosen_node is not None
             target.add_edge((target.input_node, 0), (chosen_node, None))
             target.add_edge((chosen_node, None), (target.output_node, None))
-            model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
+            operation = cast(Cell, node.operation)
+            target_node = cast(Node, model.get_node_by_name(node.name))
+            target_node.update_operation(Cell(operation.cell_name))

             # remove redundant nodes
             for rm_node in list(target.hidden_nodes):  # remove from a list on the fly will cause issues
@@ -57,7 +59,7 @@ class InputChoiceMutator(Mutator):
         else:
             chosen = [self.choice(candidates) for _ in range(n_chosen)]
         for node in self.nodes:
-            target = model.get_node_by_name(node.name)
+            target = cast(Node, model.get_node_by_name(node.name))
             target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs',
                                     {'chosen': chosen, 'reduction': node.operation.parameters['reduction']})
@@ -74,7 +76,7 @@ class ValueChoiceMutator(Mutator):
         # no need to support transformation here,
         # because it is naturally done in forward loop
         for node in self.nodes:
-            target = model.get_node_by_name(node.name)
+            target = cast(Node, model.get_node_by_name(node.name))
             target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})
@@ -86,7 +88,7 @@ class ParameterChoiceLeafMutator(Mutator):
         super().__init__(label=label)
         self.candidates = candidates

-    def mutate(self, model: Model) -> Model:
+    def mutate(self, model: Model) -> None:
         # leave a record here
         # real mutations will be done in ParameterChoiceMutator
         self.choice(self.candidates)
@@ -103,7 +105,7 @@ class ParameterChoiceMutator(Mutator):
         self.nodes = nodes

-    def mutate(self, model: Model) -> Model:
+    def mutate(self, model: Model) -> None:
         # looks like {"label1": "cat", "label2": 123}
         value_choice_decisions = {}
         for mutation in model.history:
@@ -122,7 +124,7 @@ class ParameterChoiceMutator(Mutator):
             result_value = value_choice.evaluate(leaf_node_values)

             # update model with graph mutation primitives
-            target = model.get_node_by_name(node.name)
+            target = cast(Node, model.get_node_by_name(node.name))
             target.update_operation(target.operation.type, {**target.operation.parameters, argname: result_value})
@@ -138,20 +140,20 @@ class RepeatMutator(Mutator):
         while u != graph.output_node:
             if u != graph.input_node:
                 chain.append(u)
-            assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successor}.'
+            assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successors}.'
             u = u.successors[0]
         return chain

     def mutate(self, model):
         for node in self.nodes:
             # the logic here is similar to layer choice. We find cell attached to each node.
-            target: Graph = model.graphs[node.operation.cell_name]
+            target: Graph = model.graphs[cast(Cell, node.operation).cell_name]
             chain = self._retrieve_chain_from_graph(target)
             # and we get the chosen depth (by value choice)
-            node_in_model = model.get_node_by_name(node.name)
+            node_in_model = cast(Node, model.get_node_by_name(node.name))
             # depth is a value choice in base model
             # but it's already mutated by a ParameterChoiceMutator here
-            chosen_depth = node_in_model.operation.parameters['depth']
+            chosen_depth: int = node_in_model.operation.parameters['depth']
             for edge in chain[chosen_depth - 1].outgoing_edges:
                 edge.remove()
             target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
@@ -159,8 +161,11 @@ class RepeatMutator(Mutator):
             for edge in rm_node.outgoing_edges:
                 edge.remove()
             rm_node.remove()

             # to delete the unused parameters.
-            model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
+            target_node = cast(Node, model.get_node_by_name(node.name))
+            cell_operation = cast(Cell, node.operation)
+            target_node.update_operation(Cell(cell_operation.cell_name))

 def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
@@ -241,7 +246,7 @@ class ManyChooseManyMutator(Mutator):
     Choose based on labels. Will not affect the model itself.
     """

-    def __init__(self, label: Optional[str]):
+    def __init__(self, label: str):
         super().__init__(label=label)

     @staticmethod
@@ -257,7 +262,7 @@ class ManyChooseManyMutator(Mutator):
             return node.operation.parameters['n_chosen']
         return 1

-    def mutate(self, model: Model):
+    def mutate(self, model: Model) -> None:
         # this mutate does not have any effect, but it is recorded in the mutation history
         for node in model.get_nodes_by_label(self.label):
             n_chosen = self.number_of_chosen(node)
@@ -280,12 +285,12 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]:
         if not is_model_wrapped(pytorch_model):
             raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode '
                              'if your model has init parameters.')
-        model.python_init_params = pytorch_model.trace_kwargs
+        model.python_init_params = cast(dict, pytorch_model.trace_kwargs)
     else:
         model.python_init_params = {}

     # hyper-parameter choice
-    namespace: ModelNamespace = pytorch_model._model_namespace
+    namespace: ModelNamespace = cast(ModelNamespace, pytorch_model._model_namespace)
     for param_spec in namespace.parameter_specs:
         assert param_spec.categorical and param_spec.type == 'choice'
         node = graph.add_node(f'param_spec_{param_spec.name}', 'ModelParameterChoice', {'candidates': param_spec.values})
@@ -294,7 +299,8 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]:
     for name, module in pytorch_model.named_modules():
         # tricky case: value choice that serves as parameters are stored in traced arguments
         if is_basic_unit(module):
-            for key, value in module.trace_kwargs.items():
+            trace_kwargs = cast(Dict[str, Any], module.trace_kwargs)
+            for key, value in trace_kwargs.items():
                 if isinstance(value, ValueChoiceX):
                     for i, choice in enumerate(value.inner_choices()):
                         node = graph.add_node(f'{name}.init.{key}.{i}', 'ValueChoice', {'candidates': choice.candidates})
@@ -329,14 +335,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]:
     mutators = []
     mutators_final = []
     for nodes in _group_by_label_and_type(graph.hidden_nodes):
+        label = nodes[0].label
+        assert label is not None, f'label of {nodes[0]} can not be None.'
         assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
-            f'Node with label "{nodes[0].label}" does not all have the same type.'
+            f'Node with label "{label}" does not all have the same type.'
         assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
-            f'Node with label "{nodes[0].label}" does not agree on parameters.'
+            f'Node with label "{label}" does not agree on parameters.'
         if nodes[0].operation.type == 'NasBench101Cell':
-            mutators_final.append(NasBench101Mutator(nodes[0].label))
+            # The mutation of Nas-bench-101 is special, and has to be done lastly.
+            mutators_final.append(NasBench101Mutator(label))
         else:
-            mutators.append(ManyChooseManyMutator(nodes[0].label))
+            mutators.append(ManyChooseManyMutator(label))
     return model, mutators + mutators_final
@@ -350,7 +359,7 @@ class EvaluatorValueChoiceLeafMutator(Mutator):
         super().__init__(label=label)
         self.candidates = candidates

-    def mutate(self, model: Model) -> Model:
+    def mutate(self, model: Model) -> None:
         # leave a record here
         # real mutations will be done in ParameterChoiceMutator
         self.choice(self.candidates)
@@ -388,7 +397,7 @@ class EvaluatorValueChoiceMutator(Mutator):
         return obj

-    def mutate(self, model: Model):
+    def mutate(self, model: Model) -> None:
         value_choice_decisions = {}
         for mutation in model.history:
             if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator):
@@ -454,7 +463,7 @@ def _is_all_equal(lst):
     return True

-def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
+def _group_by_label_and_type(nodes: Iterable[Node]) -> List[List[Node]]:
     result = {}
     for node in nodes:
         key = (node.label, node.operation.type)
@@ -464,7 +473,7 @@ def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
     return list(result.values())

-def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
+def _group_by_label(nodes: Iterable[Node]) -> List[List[Node]]:
     result = {}
     for node in nodes:
         label = node.operation.parameters['label']
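The comments above describe a two-phase protocol: leaf mutators only record a decision through `self.choice`, and a follow-up mutator reads `model.history` and applies everything at once (hence `mutate` returning `None`). A compressed sketch of that flow with simplified stand-in classes, not the real `Mutator`/`Model` API:

```python
from typing import Any, Dict, List

class Record:
    def __init__(self, label: str, sample: Any):
        self.label, self.samples = label, [sample]

class Model:
    def __init__(self) -> None:
        self.history: List[Record] = []

# Phase 1: a "leaf" mutator only leaves a record of its decision.
def leaf_mutate(model: Model, label: str, candidates: List[Any]) -> None:
    model.history.append(Record(label, candidates[0]))  # pretend-sampling

# Phase 2: a follow-up mutator collects every decision and applies them at once,
# mirroring the {mut.mutator.label: mut.samples for mut in model.history} dict below.
def apply_mutations(model: Model) -> Dict[str, Any]:
    return {rec.label: rec.samples[0] for rec in model.history}

m = Model()
leaf_mutate(m, 'kernel_size', [3, 5, 7])
print(apply_mutations(m))  # {'kernel_size': 3}
```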
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import logging
 from collections import OrderedDict
-from typing import Callable, List, Optional, Union, Dict
+from typing import Callable, List, Optional, Union, Dict, Tuple, cast

 import numpy as np
 import torch
@@ -89,7 +92,7 @@ def compute_vertex_channels(input_channels, output_channels, matrix):
     return vertex_channels

-def prune(matrix, ops):
+def prune(matrix, ops) -> Tuple[np.ndarray, List[Union[str, Callable[[int], nn.Module]]]]:
     """
     Prune the extraneous parts of the graph.
@@ -152,11 +155,17 @@ class _NasBench101CellFixed(nn.Module):
         assert num_nodes == len(operations) + 2 == len(adjacency_list) + 1

-        self.operations = ['IN'] + operations + ['OUT']  # add pseudo nodes
+        raw_operations: List[Union[str, Callable[[int], nn.Module]]] = list(operations)
+        del operations  # operations is no longer needed. Delete it to avoid misuse
+
+        # add pseudo nodes
+        raw_operations.insert(0, 'IN')
+        raw_operations.append('OUT')
+
         self.connection_matrix = self.build_connection_matrix(adjacency_list, num_nodes)
         del num_nodes  # raw number of nodes is no longer used

-        self.connection_matrix, self.operations = prune(self.connection_matrix, self.operations)
+        self.connection_matrix, self.operations = prune(self.connection_matrix, raw_operations)

         self.hidden_features = compute_vertex_channels(in_features, out_features, self.connection_matrix)
@@ -172,7 +181,8 @@ class _NasBench101CellFixed(nn.Module):
                 self.projections.append(projection(in_features, self.hidden_features[i]))

         for i in range(1, self.num_nodes - 1):
-            self.ops.append(operations[i - 1](self.hidden_features[i]))
+            operation = cast(Callable[[int], nn.Module], self.operations[i])
+            self.ops.append(operation(self.hidden_features[i]))
     @staticmethod
     def build_connection_matrix(adjacency_list, num_nodes):
@@ -361,7 +371,7 @@ class NasBench101Mutator(Mutator):
     # for validation purposes
     # for python execution engine

-    def __init__(self, label: Optional[str]):
+    def __init__(self, label: str):
         super().__init__(label=label)

     @staticmethod
@@ -378,9 +388,11 @@ class NasBench101Mutator(Mutator):
         return 1

     def mutate(self, model: Model):
+        max_num_edges = cast(int, None)
         for node in model.get_nodes_by_label(self.label):
             max_num_edges = node.operation.parameters['max_num_edges']
             break
+        assert max_num_edges is not None
         mutation_dict = {mut.mutator.label: mut.samples for mut in model.history}
         num_nodes = mutation_dict[f'{self.label}/num_nodes'][0]
         adjacency_list = [mutation_dict[f'{self.label}/input{i}'] for i in range(1, num_nodes)]
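The `max_num_edges = cast(int, None)` line above is a sentinel idiom: declare the final type for the checker, initialize with `None`, and let the later `assert` prove that a loop iteration actually assigned it. The same idiom in isolation:

```python
from typing import List, cast

def first_positive(values: List[int]) -> int:
    # Declared as int for the type checker, initialized with a None sentinel.
    found = cast(int, None)
    for v in values:
        if v > 0:
            found = v
            break
    # Guarantees the sentinel never escapes, mirroring the mutate() above.
    assert found is not None, 'no positive value found'
    return found

print(first_positive([-3, 0, 7, 2]))  # 7
```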
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-import inspect
-import warnings
 from pathlib import Path
-import torch
-import torch.nn as nn

 # To make auto-completion happy, we generate a _nn.py that lists out all the classes.
 nn_cache_file_path = Path(__file__).parent / '_nn.py'
-cache_valid = False
+# Update this when cache format changes, to enforce an update.
+cache_version = 2

+def validate_cache() -> bool:
+    import torch
+    cache_valid = []
+    if nn_cache_file_path.exists():
+        lines = nn_cache_file_path.read_text().splitlines()
+        for line in lines:
+            if line.startswith('# _torch_version'):
+                _cached_torch_version = line[line.find('=') + 1:].strip()
+                if _cached_torch_version == torch.__version__:
+                    cache_valid.append(True)
+            if line.startswith('# _torch_nn_cache_version'):
+                _cached_cache_version = int(line[line.find('=') + 1:].strip())
+                if _cached_cache_version == cache_version:
+                    cache_valid.append(True)
+    return len(cache_valid) >= 2 and all(cache_valid)

-if nn_cache_file_path.exists():
-    from . import _nn  # pylint: disable=no-name-in-module
-    # valid only when torch version match
-    if _nn._torch_version == torch.__version__:
-        cache_valid = True

-if not cache_valid:
+def generate_stub_file() -> str:
+    import inspect
+    import warnings
+    import torch
+    import torch.nn as nn

     _NO_WRAP_CLASSES = [
         # not an nn.Module
         'Parameter',
@@ -47,7 +63,10 @@ if not cache_valid:
         '# This file is auto-generated to make auto-completion work.',
         '# When pytorch version does not match, it will get automatically updated.',
         '# pylint: skip-file',
-        f'_torch_version = "{torch.__version__}"',
+        '# pyright: reportGeneralTypeIssues=false',
+        f'# _torch_version = {torch.__version__}',
+        f'# _torch_nn_cache_version = {cache_version}',
+        'import typing',
         'import torch.nn as nn',
         'from nni.retiarii.serializer import basic_unit',
     ]
@@ -66,10 +85,9 @@ if not cache_valid:
                               'It means your PyTorch version might not be supported.', RuntimeWarning)
                 code.append(f'{name} = nn.{name}')
             elif name in _WRAP_WITHOUT_TAG_CLASSES:
-                code.append(f'{name} = basic_unit(nn.{name}, basic_unit_tag=False)')
+                code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}, basic_unit_tag=False))')
             else:
-                code.append(f'{name} = basic_unit(nn.{name})')
+                code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}))')
             all_names.append(name)
         elif inspect.isfunction(obj) or inspect.ismodule(obj):
@@ -78,12 +96,19 @@ if not cache_valid:
     code.append(f'__all__ = {all_names}')
+    return '\n'.join(code)

+def write_cache(code: str) -> None:
     with nn_cache_file_path.open('w') as fp:
-        fp.write('\n'.join(code))
+        fp.write(code)

+code = generate_stub_file()
+if not validate_cache():
+    write_cache(code)

 # Import all modules from generated _nn.py
+del Path, validate_cache, write_cache, cache_version, nn_cache_file_path, code
 from . import _nn  # pylint: disable=no-name-in-module
 __all__ = _nn.__all__
-from ._nn import *  # pylint: disable=import-error, wildcard-import
+from ._nn import *  # pylint: disable=import-error, wildcard-import, unused-wildcard-import
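For reference, `validate_cache` above parses comment-only metadata at the top of the generated `_nn.py`, instead of importing a `_torch_version` variable. Under the new format, the generated file plausibly begins like this for a single wrapped class (version strings are illustrative):

```python
# This file is auto-generated to make auto-completion work.
# When pytorch version does not match, it will get automatically updated.
# pylint: skip-file
# pyright: reportGeneralTypeIssues=false
# _torch_version = 1.10.2
# _torch_nn_cache_version = 2
import typing
import torch.nn as nn
from nni.retiarii.serializer import basic_unit

# The cast keeps the stub's static type identical to the wrapped nn class,
# so IDEs still autocomplete Conv2d's real constructor arguments.
Conv2d = typing.cast(typing.Type[nn.Conv2d], basic_unit(nn.Conv2d))
__all__ = ['Conv2d']
```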
@@ -20,7 +20,7 @@ from .supermodule.base import BaseSuperNetModule

 __all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']

-MutationHook = Callable[[nn.Module, str, Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
+MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]

 def traverse_and_mutate_submodules(
@@ -149,11 +149,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
         The hook list will be appended by ``default_mutation_hooks`` in each one-shot module.

-        To be more specific, the input arguments are three arguments:
+        To be more specific, the input arguments are four arguments:

         #. a module that might be processed,
         #. name of the module in its parent module,
         #. a memo dict whose usage depends on the particular algorithm.
+        #. keyword arguments (configurations).

         Note that the memo should be read/written by hooks.
         There won't be any hooks called on root module.
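With the extra `Dict[str, Any]` in `MutationHook`, a hook now receives four arguments and may return a replacement module, a boolean handled flag, or both. A hedged sketch of a conforming hook; the fourth parameter's name and the exact return semantics are paraphrased from the alias above, not from the full docs:

```python
from typing import Any, Dict, Tuple, Union
import torch.nn as nn

def my_hook(module: nn.Module, name: str,
            memo: Dict[str, Any],
            mutate_kwargs: Dict[str, Any]) -> Union[nn.Module, bool, Tuple[nn.Module, bool]]:
    # Replace every Linear with a wrapped version and record it in the memo.
    if isinstance(module, nn.Linear):
        memo[name] = 'replaced'
        return nn.Sequential(module, nn.ReLU()), True  # (new module, handled)
    return False  # not handled; let other hooks process this module
```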
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+# type: ignore

 import copy
 import logging
 from collections import OrderedDict
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+# type: ignore

 import logging

 import torch
@@ -27,7 +27,7 @@ which interprets the slice and apply it on a tensor.
 """

 import operator
-from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic
+from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic, cast

 import numpy as np
 import torch
@@ -128,9 +128,10 @@ class Slicable(Generic[T]):
             raise TypeError(f'Unsupported weight type: {type(weight)}')
         self.weight = weight

-    def __getitem__(self, index: multidim_slice) -> T:
+    def __getitem__(self, index: Union[slice_type, multidim_slice]) -> T:
         if not isinstance(index, tuple):
             index = (index, )
+        index = cast(multidim_slice, index)

         # Get the dict value in index's leaves
         # There can be at most one dict
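`Slicable.__getitem__` now also accepts a bare slice and normalizes it to a one-element tuple before casting. A self-contained sketch of that normalization idiom; `SliceType` and `MultidimSlice` are illustrative stand-ins for the module's `slice_type` and `multidim_slice` aliases:

```python
from typing import Tuple, Union, cast

SliceType = Union[int, slice]
MultidimSlice = Tuple[SliceType, ...]

def normalize(index: Union[SliceType, MultidimSlice]) -> MultidimSlice:
    # Accept either a single slice or a tuple of slices; always return a tuple.
    if not isinstance(index, tuple):
        index = (index,)
    # The cast tells the checker the union has been narrowed to a tuple.
    return cast(MultidimSlice, index)

print(normalize(slice(0, 4)))       # (slice(0, 4, None),)
print(normalize((slice(0, 4), 2)))  # (slice(0, 4, None), 2)
```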
@@ -24,7 +24,7 @@ class BaseSuperNetModule(nn.Module):
     rather than their compositions.
     """

-    def resample(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def resample(self, memo: Dict[str, Any]) -> Dict[str, Any]:
         """
         Resample the super-net module.
@@ -40,7 +40,7 @@ class BaseSuperNetModule(nn.Module):
         """
         raise NotImplementedError()

-    def export(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def export(self, memo: Dict[str, Any]) -> Dict[str, Any]:
         """
         Export the final architecture within this module.
         It should have the same keys as ``search_space_spec()``.
@@ -275,11 +275,11 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
             if not arch:
                 yield name, p

-    def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
         """Differentiable. Do nothing in resample."""
         return {}

-    def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
         """Export is also random for each leaf value choice."""
         result = {}
         for name, spec in operation.search_space_spec().items():
@@ -8,11 +8,12 @@ which is commonly known as super-kernel (as in channel search), or weight entanglement.

 import inspect
 import itertools
-from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar
+from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar, cast

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor

 import nni.retiarii.nn.pytorch as retiarii_nn
 from nni.common.hpo_utils import ParameterSpec
@@ -46,11 +47,11 @@ class MixedOperationSamplingPolicy:
         """
         pass

-    def resample(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def resample(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
         """The handler of :meth:`MixedOperation.resample`."""
         raise NotImplementedError()

-    def export(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def export(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
         """The handler of :meth:`MixedOperation.export`."""
         raise NotImplementedError()
@@ -513,43 +514,42 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
         embed_dim = _W(embed_dim)

         # in projection weights & biases has q, k, v weights concatenated together
-        in_proj_bias = in_proj_weight = None
+        in_proj_bias: Optional[Tensor] = None
+        in_proj_weight: Optional[Tensor] = None
         if self.in_proj_bias is not None:
-            in_proj_bias = _S(self.in_proj_bias)[self._to_proj_slice(embed_dim)]
+            in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim)]
         if self.in_proj_weight is not None:
-            in_proj_weight = _S(self.in_proj_weight)[self._to_proj_slice(embed_dim), :embed_dim]
+            in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim), :embed_dim]

-        bias_k = _S(self.bias_k)[:, :, :embed_dim] if self.bias_k is not None else None
-        bias_v = _S(self.bias_v)[:, :, :embed_dim] if self.bias_v is not None else None
-        out_proj_weight = _S(self.out_proj.weight)[:embed_dim, :embed_dim]
-        out_proj_bias = _S(self.out_proj.bias)[:embed_dim] if self.out_proj.bias is not None else None
+        bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim] if self.bias_k is not None else None
+        bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim] if self.bias_v is not None else None
+        out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim, :embed_dim]
+        out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim] if self.out_proj.bias is not None else None

         if not qkv_same_embed_dim:
-            kdim = _W(kdim)
-            vdim = _W(vdim)
-
-            q_proj = _S(self.q_proj_weight)[:embed_dim, :embed_dim]
-            k_proj = _S(self.k_proj_weight)[:embed_dim]
-            k_proj = _S(k_proj)[:, :kdim]
-            v_proj = _S(self.v_proj_weight)[:embed_dim]
-            v_proj = _S(v_proj)[:, :vdim]
+            q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim, :embed_dim]
+            k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim]
+            k_proj = _S(k_proj)[:, :_W(kdim)]
+            v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim]
+            v_proj = _S(v_proj)[:, :_W(vdim)]

             # The rest part is basically same as pytorch
             attn_output, attn_output_weights = F.multi_head_attention_forward(
                 query, key, value, used_embed_dim, num_heads,
-                in_proj_weight, in_proj_bias,
+                cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
                 bias_k, bias_v, self.add_zero_attn,
-                dropout, out_proj_weight, out_proj_bias,
+                dropout, out_proj_weight, cast(Tensor, out_proj_bias),
                 training=self.training,
                 key_padding_mask=key_padding_mask, need_weights=need_weights,
                 attn_mask=attn_mask, use_separate_proj_weight=True,
                 q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj)
         else:
+            # Cast tensor here because of a bug in pytorch stub
             attn_output, attn_output_weights = F.multi_head_attention_forward(
                 query, key, value, used_embed_dim, num_heads,
-                in_proj_weight, in_proj_bias,
+                cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
                 bias_k, bias_v, self.add_zero_attn,
-                dropout, out_proj_weight, out_proj_bias,
+                dropout, out_proj_weight, cast(Tensor, out_proj_bias),
                 training=self.training,
                 key_padding_mask=key_padding_mask, need_weights=need_weights,
                 attn_mask=attn_mask)
@@ -9,7 +9,7 @@ The support remains limited. Known limitations include:
 - The code contains duplicates. Needs refactor.
 """

-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, cast

 import torch
 import torch.nn as nn
@@ -94,7 +94,7 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
             self._sample_idx = self.op_names.index(self._sampled)
         else:
             probs = self._softmax(self._arch_alpha)
-            self._sample_idx = torch.multinomial(probs, 1)[0].item()
+            self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
             self._sampled = self.op_names[self._sample_idx]

         # set binary gates
@@ -109,10 +109,11 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
         """Choose the argmax if label isn't found in memo."""
         if self.label in memo:
             return {}  # nothing new to export
-        return {self.label: self.op_names[torch.argmax(self._arch_alpha).item()]}
+        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

     def finalize_grad(self):
         binary_grads = self._binary_gates.grad
+        assert binary_grads is not None
         with torch.no_grad():
             if self._arch_alpha.grad is None:
                 self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
@@ -164,13 +165,13 @@ class ProxylessMixedInput(DifferentiableMixedInput):
         else:
             probs = self._softmax(self._arch_alpha)
             sample = torch.multinomial(probs, 1)[0].item()
-            self._sampled = sample
+            self._sampled = int(sample)

         # set binary gates
         with torch.no_grad():
             self._binary_gates.zero_()
             self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
-            self._binary_gates.data[sample] = 1.0
+            self._binary_gates.data[cast(int, self._sampled)] = 1.0

         return {self.label: self._sampled}
@@ -182,6 +183,7 @@ class ProxylessMixedInput(DifferentiableMixedInput):
     def finalize_grad(self):
         binary_grads = self._binary_gates.grad
+        assert binary_grads is not None
         with torch.no_grad():
             if self._arch_alpha.grad is None:
                 self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
@@ -129,6 +129,8 @@ class PathSamplingInput(BaseSuperNetModule):
         if isinstance(module, InputChoice):
             if module.reduction not in ['sum', 'mean', 'concat']:
                 raise ValueError('Only input choice of sum/mean/concat reduction is supported.')
+            if module.n_chosen is None:
+                raise ValueError('n_chosen is None is not supported yet.')
             return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)

     def forward(self, input_tensors):
@@ -161,7 +163,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
         # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
         self._sampled: Optional[Dict[str, Any]] = None

-    def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
         """Random sample for each leaf value choice."""
         result = {}
         space_spec = operation.search_space_spec()
@@ -179,7 +181,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
         return result

-    def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
         """Export is also random for each leaf value choice."""
         result = {}
         space_spec = operation.search_space_spec()
@@ -132,7 +132,7 @@ def _replace_module_with_type(root_module, init_fn, type_name, modules):
         for name, child in m.named_children():
             if isinstance(child, type_name):
                 setattr(m, name, init_fn(child))
-                modules.append((child.key, getattr(m, name)))
+                modules.append((child.label, getattr(m, name)))
             else:
                 apply(child)
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import (Any, Dict, List)
+from typing import (Any, Dict, List, Optional, cast)

 from . import debug_configs
@@ -34,6 +34,8 @@ class Operation:
         Arbitrary key-value parameters (e.g. kernel_size).
     """

+    io_names: List[str] = []

     def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False, attributes: Dict[str, Any] = {}):
         assert _internal, '`Operation()` is private, use `Operation.new()` instead'
         self.type: str = type_name
@@ -43,7 +45,7 @@ class Operation:
     def to_init_code(self, field: str) -> str:
         raise NotImplementedError()

-    def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         raise NotImplementedError()

     def _to_class_name(self) -> str:
@@ -53,8 +55,8 @@ class Operation:
         return True

     @staticmethod
-    def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None,
-            attributes: Dict[str, Any] = None) -> 'Operation':
+    def new(type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None), cell_name: str = cast(str, None),
+            attributes: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Operation':
         parameters = parameters or {}
         attributes = attributes or {}
         if type_name == '_cell':
@@ -98,16 +100,16 @@ class PyTorchOperation(Operation):
             subclass_name = 'FunctionalOperator'
         for subclass in cls.__subclasses__():
             if hasattr(subclass, '_ori_type_name') and \
-                    subclass_name in subclass._ori_type_name:
+                    subclass_name in cast(Any, subclass)._ori_type_name:
                 return subclass
         for subclass in cls.__subclasses__():
             if hasattr(subclass, '_artificial_op_name') and \
-                    subclass_name in subclass._artificial_op_name:
+                    subclass_name in cast(Any, subclass)._artificial_op_name:
                 return subclass
         return cls

     @classmethod
-    def to_class_name(cls, type_name) -> str:
+    def to_class_name(cls, type_name) -> Optional[str]:
         if type_name.startswith('__torch__.'):
             return type_name[len('__torch__.'):]
         elif type_name.startswith('__mutated__.'):
@@ -119,7 +121,7 @@ class PyTorchOperation(Operation):
     def is_functional(cls, type_name) -> bool:
         return type_name.startswith('Function.')

-    def _to_class_name(self) -> str:
+    def _to_class_name(self) -> Optional[str]:
         if self.type.startswith('__torch__.'):
             return self.type[len('__torch__.'):]
         elif self.type.startswith('__mutated__.'):
@@ -127,7 +129,7 @@ class PyTorchOperation(Operation):
         else:
             return None

-    def get_import_pkg(self) -> str:
+    def get_import_pkg(self) -> Optional[str]:
         if self.type.startswith('__torch__.'):
             return self.type[len('__torch__.'):].split('.')[0]
         elif self.type.startswith('__mutated__.'):
@@ -135,14 +137,14 @@ class PyTorchOperation(Operation):
         else:
             return None

-    def to_init_code(self, field: str) -> str:
+    def to_init_code(self, field: str) -> Optional[str]:
         if self._to_class_name() is not None:
             assert 'positional_args' not in self.parameters
             kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
             return f'self.{field} = {self._to_class_name()}({kw_params})'
         return None

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         """
         Parameters
         ----------
@@ -207,7 +209,9 @@ class Cell(PyTorchOperation):
     No real usage. Exists for compatibility with base class.
     """

-    def __init__(self, cell_name: str, parameters: Dict[str, Any] = None, attributes: Dict[str, Any] = None):
+    def __init__(self, cell_name: str,
+                 parameters: Dict[str, Any] = cast(Dict[str, Any], None),
+                 attributes: Dict[str, Any] = cast(Dict[str, Any], None)):
         self.type = '_cell'
         self.cell_name = cell_name
         self.parameters = parameters or {}
@@ -217,7 +221,7 @@ class Cell(PyTorchOperation):
         # TODO: ugly, think about how to refactor this part
         return _convert_name(self.cell_name)

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = self.{field}({", ".join(inputs)})'

 class _IOPseudoOperation(Operation):
@@ -227,7 +231,7 @@ class _IOPseudoOperation(Operation):
     especially in static type checking.
     """

-    def __init__(self, type_name: str, io_names: List = None):
+    def __init__(self, type_name: str, io_names: List[str] = cast(List[str], None)):
         assert type_name.startswith('_')
         super(_IOPseudoOperation, self).__init__(type_name, {}, True)
         self.io_names = io_names
@@ -235,7 +239,7 @@ class _IOPseudoOperation(Operation):
     def to_init_code(self, field: str) -> str:
         raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

-    def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

     def __bool__(self) -> bool:
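The `cast(Dict[str, Any], None)` defaults in `Operation.new` and `Cell.__init__` are a sentinel idiom: the annotation stays non-Optional, `None` sneaks in as the default (avoiding a shared mutable default), and `parameters or {}` normalizes it on entry. A minimal sketch:

```python
from typing import Any, Dict, cast

def new_op(type_name: str,
           parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> Dict[str, Any]:
    # Normalize the None sentinel; callers see a non-Optional annotation.
    parameters = parameters or {}
    return {'type': type_name, **parameters}

print(new_op('conv'))                      # {'type': 'conv'}
print(new_op('conv', {'kernel_size': 3}))  # {'type': 'conv', 'kernel_size': 3}
```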
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 """
 Definition of operation types.
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations

 from typing import (Any, Dict, List)

 import torch
+import torch.nn.functional as nn_functional

 from ..operation import PyTorchOperation
@@ -39,23 +42,23 @@ class NoOpIdentity(PyTorchOperation):
     """
     _ori_type_name = ['noop_identity']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {", ".join(inputs)}'

 class ModuleOperator(PyTorchOperation):
     _ori_type_name = ['ModuleOperator', 'shared']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = self.{field}({", ".join(inputs)})'

 class FunctionalOperator(PyTorchOperation):
     _ori_type_name = ['FunctionalOperator']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         func_name = self.type[len('Function.'):]
-        if not hasattr(torch.nn.functional, func_name):
+        if not hasattr(nn_functional, func_name):
             raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
                                f'{func_name} is not in it.')
         return f'{output} = F.{func_name}({", ".join(inputs)})'
@@ -64,7 +67,7 @@ class FunctionalOperator(PyTorchOperation):

 class PrimConstant(PyTorchOperation):
     _ori_type_name = ['prim::Constant']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         # TODO: refactor this part, maybe we can remove the code gen of prim::Constant
         # TODO: deal with all the types
         if self.parameters['type'] in ['None', 'NoneType']:
class PrimListConstruct(PyTorchOperation):
_ori_type_name = ['prim::ListConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = [{", ".join(inputs)}]'
class PrimListUnpack(PyTorchOperation):
_ori_type_name = ['prim::ListUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}'
class PrimTupleConstruct(PyTorchOperation):
_ori_type_name = ['prim::TupleConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = ({", ".join(inputs)})'
class PrimTupleUnpack(PyTorchOperation):
_ori_type_name = ['prim::TupleUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# have single output here, because the following code uses index to access the unpacked values
assert len(inputs) == 1
return f'{output} = {inputs[0]}'
......@@ -117,7 +120,7 @@ class PrimTupleUnpack(PyTorchOperation):
class PrimGetAttr(PyTorchOperation):
_ori_type_name = ['prim::GetAttr']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.parameters['value'] is not None:
return f"{output} = {self.parameters['value']}"
else:
......@@ -127,14 +130,14 @@ class PrimGetAttr(PyTorchOperation):
class PrimUncheckedCast(PyTorchOperation):
_ori_type_name = ['prim::unchecked_cast']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}'
class SimpleMember(PyTorchOperation):
_ori_type_name = ['prim::is_cuda', 'prim::data']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
member_name = self.type.split('::')[-1]
return f'{output} = {inputs[0]}.{member_name}'
......@@ -142,16 +145,16 @@ class SimpleMember(PyTorchOperation):
class AtenContiguous(PyTorchOperation):
_ori_type_name = ['aten::contiguous']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# defined in pytorch/c10/core/MemoryFormat.h
assert inputs_value[1] in [0, 1, 2]
assert inputs_value is not None and inputs_value[1] in [0, 1, 2]
return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'
 class AtenGetitem(PyTorchOperation):
     _ori_type_name = ['aten::__getitem__']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         assert len(inputs) == 2
         return f'{output} = {inputs[0]}[{inputs[1]}]'
@@ -159,7 +162,7 @@ class AtenGetitem(PyTorchOperation):

 class AtenAppend(PyTorchOperation):
     _ori_type_name = ['aten::append']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         assert len(inputs) == 2
         return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
@@ -167,7 +170,7 @@ class AtenAppend(PyTorchOperation):

 class MergedSlice(PyTorchOperation):
     _ori_type_name = ['MergedSlice']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         if (len(inputs) - 1) % 4 == 0:
             slices = []
             dim = int((len(inputs) - 1) / 4)
@@ -187,21 +190,21 @@ class MergedSlice(PyTorchOperation):

 class AtenBool(PyTorchOperation):
     _ori_type_name = ['aten::Bool']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = bool({inputs[0]})'

 class AtenNot(PyTorchOperation):
     _ori_type_name = ['aten::__not__']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = not {inputs[0]}'

 class AtenCat(PyTorchOperation):
     _ori_type_name = ['aten::cat']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         assert len(inputs) == 2
         return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
@@ -215,7 +218,7 @@ class AtenTensors(PyTorchOperation):
                       'aten::new_empty', 'aten::new_zeros', 'aten::arange',
                       'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         schemas = torch._C._jit_get_schemas_for_operator(self.type)
         # match number of inputs
         overloaded_defs = [len(s.arguments) for s in schemas]
class AtenFloordiv(PyTorchOperation):
_ori_type_name = ['aten::floordiv']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]} // {inputs[1]}'
class AtenMul(PyTorchOperation):
_ori_type_name = ['aten::mul']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]} * {inputs[1]}'
class AtenLen(PyTorchOperation):
_ori_type_name = ['aten::len']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = len({inputs[0]})'
class AtenIntImplicit(PyTorchOperation):
_ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.type.endswith('Implicit'):
return f'{output} = {inputs[0]}'
elif self.type == 'aten::Int':
return f'{output} = int({inputs[0]})'
elif self.type == 'aten::Float':
return f'{output} = float({inputs[0]})'
raise TypeError(f'Unexpected type: {self.type}')
class AtenIndex(PyTorchOperation):
_ori_type_name = ['aten::index']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}[{inputs[1]}]'
......@@ -355,13 +359,13 @@ def _get_tensor_ops():
def _get_torch_ops():
torch_op_args = {}
for mod in torch.jit._builtins._modules_containing_builtins:
for mod in torch.jit._builtins._modules_containing_builtins: # type: ignore
name = mod.__name__
if name == 'torch._C._nn':
continue
# only process 'torch.XXX'
for elem in dir(mod):
builtin = torch.jit._builtins._find_builtin(getattr(mod, elem))
builtin = torch.jit._builtins._find_builtin(getattr(mod, elem)) # type: ignore
if builtin is not None:
schemas = torch._C._jit_get_schemas_for_operator(builtin)
for schema in schemas:
@@ -436,7 +440,7 @@ class TensorOps(PyTorchOperation):
             return None
         raise RuntimeError(f'tensor op type {_type} has no match')

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         # TODO: deal with conditional ops
         if self.type in TensorOps.comparison_ops:
             return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
@@ -486,7 +490,7 @@ class TorchOps(PyTorchOperation):
         else:
             raise RuntimeError(f'torch op type {_type} has no match')

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         matched_args = TorchOps._get_matched_args(self.type, inputs)
         op_name = self.type.split('::')[-1]
         args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
@@ -498,7 +502,7 @@ class AtenAvgpool2d(PyTorchOperation):
     # NOTE: it is not included in the above aten ops for unknown reason
     _ori_type_name = ['aten::avg_pool2d']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = F.avg_pool2d({", ".join(inputs)})'
@@ -506,7 +510,7 @@ class ToDevice(PyTorchOperation):
     _artificial_op_name = "ToDevice"

     def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False,
-                 attributes: Dict[str, Any] = None):
+                 attributes: Dict[str, Any] = {}):
         self.type = "ToDevice"
         self.device = parameters['device']
         self.overridden_device_repr = None
@@ -540,5 +544,5 @@ class AtenDet(PyTorchOperation):
     # NOTE: it is not included in the above aten ops, maybe because torch.det is an alias for torch.linalg.det
     _ori_type_name = ['aten::linalg_det']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = torch.det({inputs[0]})'
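The change repeated across this file makes `inputs_value` a required parameter of `to_forward_code` instead of defaulting to `None`. A sketch of a custom operation written against the new signature; `AtenRelu` is a hypothetical example, not part of NNI:

```python
from typing import Any, List
# Assuming the PyTorchOperation base class from nni.retiarii.operation,
# imported as `from ..operation import PyTorchOperation` in the hunks above.
from nni.retiarii.operation import PyTorchOperation

class AtenRelu(PyTorchOperation):  # hypothetical example operation
    _ori_type_name = ['aten::relu']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        # inputs holds variable names in the generated forward(); inputs_value
        # carries constant-folded values when the engine knows them.
        return f'{output} = F.relu({inputs[0]})'
```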