Unverified Commit a0fd0036 authored by Yuge Zhang, committed by GitHub

Merge pull request #5036 from microsoft/promote-retiarii-to-nas

[DO NOT SQUASH] Promote retiarii to NAS
parents d6dcb483 bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from enum import Enum
# except special cases that cannot be treated as basic modules from pytorch
MODULE_EXCEPT_LIST = ['Sequential']
class OpTypeName(str, Enum):
"""
Mapping from op type to its type name string.
"""
Attr = 'Attr'
Constant = 'Constant'
LayerChoice = 'LayerChoice'
InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
Repeat = 'Repeat'
Cell = 'Cell'
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional
from typing_extensions import TypeGuard
from nni.nas.execution.common import Cell, Model, Graph, Node, Edge
def build_full_name(prefix, name, seq=None):
if isinstance(name, list):
name = '__'.join(name)
if seq is None:
return '{}__{}'.format(prefix, name)
else:
return '{}__{}{}'.format(prefix, name, str(seq))
def build_python_name(prefix, name):
if isinstance(name, list):
name = '.'.join(name)
if prefix:
return '{}.{}'.format(prefix, name)
else: # prefix could be None
return name
def build_cand_name(name, label):
return f'layerchoice_{label}_{name}'
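# Illustrative examples (not from the original source) of what the naming helpers above produce:
#   build_full_name('model', 'conv', 2)          -> 'model__conv2'
#   build_python_name('model', ['cell', 'op'])   -> 'model.cell.op'
#   build_cand_name('conv', 'mutable_0')         -> 'layerchoice_mutable_0_conv'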
def _convert_name(name: str) -> str:
"""
Convert names that use the separator '.' into valid variable names in code.
"""
return name.replace('.', '__')
def _extract_info_from_trace_node(trace_node):
"""
Extract parameters from a trace node.
Parameters
----------
trace_node: torch._C.Node
"""
input_shape = []
output_shape = []
inputs = list(trace_node.inputs())
# for aten::cat, the input tensors are wrapped in a list-construct node
if trace_node.kind() == 'aten::cat':
input_shape = [input.type().sizes() for input in inputs[0].node().inputs()]
else:
for _input in inputs:
input_type = _input.type()
if input_type.kind() == 'TensorType':
shape = input_type.sizes()
if shape:
input_shape.append(shape)
for _output in trace_node.outputs():
output_type = _output.type()
if output_type.kind() == 'TensorType':
shape = output_type.sizes()
if shape:
output_shape.append(shape)
shape_parameters = {
'input_shape': input_shape,
'output_shape': output_shape,
}
if trace_node.kind() == 'aten::cat':
parameters = {'dim': inputs[1].toIValue()}
return shape_parameters, parameters
else:
return shape_parameters, None
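# Illustrative return value (shapes are hypothetical) for an `aten::cat` trace node that
# concatenates two (1, 16, 32, 32) tensors along dim=1:
#   shape_parameters = {'input_shape': [[1, 16, 32, 32], [1, 16, 32, 32]],
#                       'output_shape': [[1, 32, 32, 32]]}
#   parameters = {'dim': 1}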
def is_layerchoice_node(ir_node: Optional[Node]) -> TypeGuard[Node]:
if ir_node is not None and isinstance(ir_node.operation, Cell) and ir_node.operation.parameters.get('mutation') == 'layerchoice':
return True
else:
return False
def get_full_name_by_scope_name(ir_model: Model, scope_names, prefix=''):
full_name = prefix
for last_scope in range(len(scope_names)):
ir_node = ir_model.get_node_by_name(full_name)
# check if it's layerchoice
if is_layerchoice_node(ir_node):
full_name = f'layerchoice_{ir_node.operation.parameters["label"]}_{scope_names[last_scope]}'
else:
full_name = build_full_name(full_name, scope_names[last_scope])
return full_name
def match_node(ir_model: Model, torch_node, prefix=''):
"""
Match the IR node corresponding to a torch._C.Node.
"""
scope_names = torch_node.scopeName().split('/')[-1].split('.')[1:]
full_name = get_full_name_by_scope_name(ir_model, scope_names, prefix)
# handle the case when the node is not an nn.Module but is directly used in forward()
# Because the name can't be directly matched, we use a hacky way:
# we match the first unshaped node of that kind.
graph = ir_model.graphs.get(full_name)
if graph is not None:
for node in graph.get_nodes_by_type(torch_node.kind()):
if not node.operation.attributes['input_shape']:
return node
return None
else:
return ir_model.get_node_by_name(full_name)
def _without_shape_info(node: Node):
return not node.operation.attributes['input_shape'] and not node.operation.attributes['output_shape']
def flatten_model_graph(ir_model: Model):
"""
Flatten all subgraphs into the root graph.
"""
def _flatten(graph: Graph):
"""
flatten this graph
"""
model = graph.model
node_to_remove = []
for node in graph.hidden_nodes:
node_graph = model.graphs.get(node.name)
if node_graph is not None:
_flatten(node_graph)
# flatten node graph into this graph
id_to_new_node = {}
for node_graph_node in node_graph.hidden_nodes:
new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True)
new_node.update_label(node_graph_node.label)
new_node._register()
id_to_new_node[new_node.id] = new_node
# reconnect node edges
for in_edge in node.incoming_edges:
graph.del_edge(in_edge)
for input_node_edge in node_graph.input_node.outgoing_edges:
if input_node_edge.head_slot == in_edge.tail_slot:
graph.add_edge(
head=(in_edge.head, in_edge.head_slot),
tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot))
for out_edge in node.outgoing_edges:
graph.del_edge(out_edge)
for output_node_edge in node_graph.output_node.incoming_edges:
if output_node_edge.head_slot == out_edge.tail_slot:
graph.add_edge(
head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot),
tail=(out_edge.tail, out_edge.tail_slot))
for edge in node_graph.edges:
if edge.head == node_graph.input_node or edge.tail == node_graph.output_node:
continue
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
node_to_remove.append(node)
del model.graphs[node.name]
for node in node_to_remove:
node.remove()
new_ir_model = ir_model.fork()
_flatten(new_ir_model.root_graph)
# remove subgraphs
new_ir_model.graphs = {new_ir_model._root_graph_name: new_ir_model.root_graph}
return new_ir_model
def flatten_model_graph_without_layerchoice(ir_model: Model):
"""
Flatten all subgraphs into the root graph, skipping all layerchoice nodes.
"""
def _flatten_without_layerchoice(graph: Graph):
"""
flatten this graph
"""
model = graph.model
node_to_remove = []
for node in graph.hidden_nodes:
if is_layerchoice_node(node):
for in_edge in node.incoming_edges:
graph.del_edge(in_edge)
for out_edge in node.outgoing_edges:
graph.del_edge(out_edge)
del model.graphs[node.name]
node.remove()
return
node_graph = model.graphs.get(node.name)
if node_graph is not None:
_flatten_without_layerchoice(node_graph)
# flatten node graph into this graph
id_to_new_node = {}
for node_graph_node in node_graph.hidden_nodes:
new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True)
new_node.update_label(node_graph_node.label)
new_node._register()
id_to_new_node[new_node.id] = new_node
# reconnect node edges
for in_edge in node.incoming_edges:
graph.del_edge(in_edge)
for input_node_edge in node_graph.input_node.outgoing_edges:
if input_node_edge.head_slot == in_edge.tail_slot:
graph.add_edge(
head=(in_edge.head, in_edge.head_slot),
tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot))
for out_edge in node.outgoing_edges:
graph.del_edge(out_edge)
for output_node_edge in node_graph.output_node.incoming_edges:
if output_node_edge.head_slot == out_edge.tail_slot:
graph.add_edge(
head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot),
tail=(out_edge.tail, out_edge.tail_slot))
for edge in node_graph.edges:
if edge.head == node_graph.input_node or edge.tail == node_graph.output_node:
continue
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
node_to_remove.append(node)
del model.graphs[node.name]
for node in node_to_remove:
node.remove()
new_ir_model = ir_model.fork()
_flatten_without_layerchoice(new_ir_model.root_graph)
# remove subgraphs
new_ir_model.graphs = {new_ir_model._root_graph_name: new_ir_model.root_graph}
return new_ir_model
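# Usage sketch (assuming `ir_model` is a graph-based Model with nested cell subgraphs):
#   flat = flatten_model_graph(ir_model)                              # keeps layerchoice cells
#   flat_no_lc = flatten_model_graph_without_layerchoice(ir_model)    # drops layerchoice cells
# Both return a fork of `ir_model` whose root graph has all hidden subgraph nodes inlined.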
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import graphviz
def convert_to_visualize(graph_ir, vgraph):
for name, graph in graph_ir.items():
if name == '_evaluator':
continue
with vgraph.subgraph(name='cluster'+name) as subgraph:
subgraph.attr(color='blue')
cell_node = {}
ioput = {'_inputs': '{}-{}'.format(name, '_'.join(graph['inputs'])),
'_outputs': '{}-{}'.format(name, '_'.join(graph['outputs']))}
subgraph.node(ioput['_inputs'])
subgraph.node(ioput['_outputs'])
for node_name, node_value in graph['nodes'].items():
value = node_value['operation']
if value['type'] == '_cell':
cell_input_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['inputs']))
cell_output_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['outputs']))
cell_node[node_name] = (cell_input_name, cell_output_name)
print('cell: ', node_name, cell_input_name, cell_output_name)
else:
subgraph.node(node_name)
for edge in graph['edges']:
src = edge['head'][0]
if src == '_inputs':
src = ioput['_inputs']
elif src in cell_node:
src = cell_node[src][1]
dst = edge['tail'][0]
if dst == '_outputs':
dst = ioput['_outputs']
elif dst in cell_node:
dst = cell_node[dst][0]
subgraph.edge(src, dst)
def visualize_model(graph_ir):
vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
convert_to_visualize(graph_ir, vgraph)
vgraph.render()
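# Usage sketch (`graph_ir` here is a hypothetical IR dict, e.g. dumped from a graph-based Model):
#   visualize_model(graph_ir)   # renders the graph to 'vgraph.jpg' via graphviz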
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['BaseGraphData', 'BaseExecutionEngine']
import logging
import os
import random
import string
from typing import Any, Dict, Iterable, List
from nni.experiment import rest
from nni.nas.execution.common import (
AbstractExecutionEngine, AbstractGraphListener, RetiariiAdvisor, get_mutation_summary,
Model, ModelStatus, MetricData, Evaluator,
send_trial, receive_trial_parameters, get_advisor
)
from nni.nas.utils import import_
from .codegen import model_to_pytorch_script
_logger = logging.getLogger(__name__)
class BaseGraphData:
"""
Data sent between strategy and trial, in graph-based execution engine.
Attributes
----------
model_script
code of an instantiated PyTorch model
evaluator
training approach for model_script
mutation_summary
a dict of all the choices during mutations in the HPO search space format
"""
def __init__(self, model_script: str, evaluator: Evaluator, mutation_summary: dict) -> None:
self.model_script = model_script
self.evaluator = evaluator
self.mutation_summary = mutation_summary
def dump(self) -> dict:
return {
'model_script': self.model_script,
# engine needs to call dump here,
# otherwise, evaluator will become binary
# also, evaluator can be none in tests
'evaluator': self.evaluator._dump() if self.evaluator is not None else None,
'mutation_summary': self.mutation_summary
}
@staticmethod
def load(data) -> 'BaseGraphData':
return BaseGraphData(data['model_script'], Evaluator._load(data['evaluator']), data['mutation_summary'])
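# Illustrative round trip (the model script, summary and `my_evaluator` below are placeholders;
# evaluator handling follows dump()/load() above):
#   data = BaseGraphData(model_script='class _model(nn.Module): ...', evaluator=my_evaluator,
#                        mutation_summary={'layerchoice_1': 'conv3x3'})
#   payload = data.dump()                    # evaluator serialized via Evaluator._dump()
#   restored = BaseGraphData.load(payload)   # evaluator restored via Evaluator._load()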
class BaseExecutionEngine(AbstractExecutionEngine):
"""
The execution engine with no optimization at all.
Resource management is implemented in this class.
"""
def __init__(self, rest_port: int | None = None, rest_url_prefix: str | None = None) -> None:
"""
Upon initialization, advisor callbacks need to be registered.
Advisor will call the callbacks when the corresponding event has been triggered.
Base execution engine will get those callbacks and broadcast them to graph listener.
Parameters
----------
rest_port
The port of the experiment's rest server
rest_url_prefix
The url prefix of the experiment's rest entry
"""
self.port = rest_port
self.url_prefix = rest_url_prefix
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self._history: List[Model] = []
self.resources = 0
# register advisor callbacks
advisor: RetiariiAdvisor = get_advisor()
advisor.register_callbacks({
'send_trial': self._send_trial_callback,
'request_trial_jobs': self._request_trial_jobs_callback,
'trial_end': self._trial_end_callback,
'intermediate_metric': self._intermediate_metric_callback,
'final_metric': self._final_metric_callback
})
def submit_models(self, *models: Model) -> None:
for model in models:
data = self.pack_model_data(model)
self._running_models[send_trial(data.dump())] = model
self._history.append(model)
def list_models(self) -> Iterable[Model]:
return self._history
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
self._listeners.append(listener)
def _send_trial_callback(self, paramater: dict) -> None:
if self.resources <= 0:
# FIXME: should be a warning message here
_logger.debug('There is no available resource, but trial is submitted.')
self.resources -= 1
_logger.debug('Resource used. Remaining: %d', self.resources)
def _request_trial_jobs_callback(self, num_trials: int) -> None:
self.resources += num_trials
_logger.debug('New resource available. Remaining: %d', self.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id]
if success:
model.status = ModelStatus.Trained
else:
model.status = ModelStatus.Failed
for listener in self._listeners:
listener.on_training_end(model, success)
def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
model = self._running_models[trial_id]
model.intermediate_metrics.append(metrics)
for listener in self._listeners:
listener.on_intermediate_metric(model, metrics)
def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
model = self._running_models[trial_id]
model.metric = metrics
for listener in self._listeners:
listener.on_metric(model, metrics)
def query_available_resource(self) -> int:
return self.resources
def budget_exhausted(self) -> bool:
resp = rest.get(self.port, '/check-status', self.url_prefix)
return resp['status'] == 'DONE'
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation_summary = get_mutation_summary(model)
assert model.evaluator is not None, 'Model evaluator can not be None'
return BaseGraphData(model_to_pytorch_script(model), model.evaluator, mutation_summary) # type: ignore
@classmethod
def trial_execute_graph(cls) -> None:
"""
Initialize the model, hand it over to trainer.
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
model_cls = import_(f'_generated_model.{random_str}._model')
graph_data.evaluator._execute(model_cls)
os.remove(file_name)
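# Trial-side sketch: on the trial machine, trial_execute_graph() above receives the serialized
# BaseGraphData, writes the generated model to _generated_model/<random>.py, imports its `_model`
# class and hands it to the evaluator. It is invoked by the trial entrypoint's main()
# (see trial_entry further below), not called directly by users.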
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from typing import (Any, Dict, List)
import torch
import torch.nn.functional as nn_functional
from nni.nas.execution.common import PyTorchOperation
mem_format = [
'torch.contiguous_format', # 0
'torch.preserve_format', # 1
'torch.channels_last', # 2
]
# this snippet is copied from torch/onnx/symbolic_helper.py,
# the original definition is in c10/core/ScalarType.h
# it maps each scalar type code to its corresponding PyTorch dtype string
scalar_type_to_pytorch_type = [
'torch.uint8', # 0
'torch.int8', # 1
'torch.short', # 2
'torch.int', # 3
'torch.int64', # 4
'torch.half', # 5
'torch.float', # 6
'torch.double', # 7
'torch.complex32', # 8
'torch.complex64', # 9
'torch.complex128', # 10
'torch.bool', # 11
]
class NoOpIdentity(PyTorchOperation):
"""
this operator type is added by us
"""
_ori_type_name = ['noop_identity']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {", ".join(inputs)}'
class ModuleOperator(PyTorchOperation):
_ori_type_name = ['ModuleOperator', 'shared']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = self.{field}({", ".join(inputs)})'
class FunctionalOperator(PyTorchOperation):
_ori_type_name = ['FunctionalOperator']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
func_name = self.type[len('Function.'):]
if not hasattr(nn_functional, func_name):
raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
f'{func_name} is not in it.')
return f'{output} = F.{func_name}({", ".join(inputs)})'
class PrimConstant(PyTorchOperation):
_ori_type_name = ['prim::Constant']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant
# TODO: deal with all the types
if self.parameters['type'] in ['None', 'NoneType']:
return f'{output} = None'
elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'): # 'Long()' ???
return f'{output} = {self.parameters["value"]}'
elif self.parameters['type'] == 'str':
str_val = self.parameters["value"]
return f'{output} = "{str_val}"'
elif self.parameters['type'] == 'Device':
value = self.parameters['value']
return f'{output} = torch.device("{value}")'
elif self.parameters['type'] in ('dict', 'list', 'tuple'):
# TODO: prim::TupleIndex is not supported yet
return f'{output} = {repr(self.parameters["value"])}'
else:
raise RuntimeError(f'unsupported type of prim::Constant: {self.parameters["type"]}')
class PrimListConstruct(PyTorchOperation):
_ori_type_name = ['prim::ListConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = [{", ".join(inputs)}]'
class PrimListUnpack(PyTorchOperation):
_ori_type_name = ['prim::ListUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}'
class PrimTupleConstruct(PyTorchOperation):
_ori_type_name = ['prim::TupleConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = ({", ".join(inputs)})'
class PrimTupleUnpack(PyTorchOperation):
_ori_type_name = ['prim::TupleUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# there is a single output here, because the following code uses indexing to access the unpacked values
assert len(inputs) == 1
return f'{output} = {inputs[0]}'
class PrimGetAttr(PyTorchOperation):
_ori_type_name = ['prim::GetAttr']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.parameters['value'] is not None:
return f"{output} = {self.parameters['value']}"
else:
return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
class PrimUncheckedCast(PyTorchOperation):
_ori_type_name = ['prim::unchecked_cast']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}'
class SimpleMember(PyTorchOperation):
_ori_type_name = ['prim::is_cuda', 'prim::data']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
member_name = self.type.split('::')[-1]
return f'{output} = {inputs[0]}.{member_name}'
class AtenContiguous(PyTorchOperation):
_ori_type_name = ['aten::contiguous']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# defined in pytorch/c10/core/MemoryFormat.h
assert inputs_value is not None and inputs_value[1] in [0, 1, 2]
return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'
class AtenGetitem(PyTorchOperation):
_ori_type_name = ['aten::__getitem__']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
assert len(inputs) == 2
return f'{output} = {inputs[0]}[{inputs[1]}]'
class AtenAppend(PyTorchOperation):
_ori_type_name = ['aten::append']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
assert len(inputs) == 2
return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
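# Note on AtenAppend above: `list.append` returns None, so the generated tuple assignment,
# e.g. `_, out = xs.append(x), xs`, makes the op's output refer to the mutated list itself
# rather than None.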
class MergedSlice(PyTorchOperation):
_ori_type_name = ['MergedSlice']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if (len(inputs) - 1) % 4 == 0:
slices = []
dim = int((len(inputs) - 1) / 4)
for i in range(dim):
slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
slice_str = ','.join(slices)
return f'{output} = {inputs[0]}[{slice_str}]'
elif len(inputs) == 4:
# this case is for simple list
return f'{output} = {inputs[0]}[{inputs[1]}:{inputs[2]}:{inputs[3]}]'
else:
raise RuntimeError('Unsupported slice pattern')
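# Illustrative MergedSlice codegen: with inputs
#   [x, dim0, start0, end0, step0, dim1, start1, end1, step1]
# the first branch above emits `out = x[start0:end0:step0,start1:end1:step1]`; with the
# 4-input form [xs, start, end, step] it emits `out = xs[start:end:step]`.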
# the following Aten classes mean that these aten ops are not methods of torch.Tensor
class AtenBool(PyTorchOperation):
_ori_type_name = ['aten::Bool']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = bool({inputs[0]})'
class AtenNot(PyTorchOperation):
_ori_type_name = ['aten::__not__']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = not {inputs[0]}'
class AtenCat(PyTorchOperation):
_ori_type_name = ['aten::cat']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
# ====================================
class AtenTensors(PyTorchOperation):
_ori_type_name = ['aten::full', 'aten::full_like', 'aten::empty_like',
'aten::ones_like', 'aten::zeros_like', 'aten::rand',
'aten::randn', 'aten::scalar_tensor', 'aten::new_full',
'aten::new_empty', 'aten::new_zeros', 'aten::arange',
'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
schemas = torch._C._jit_get_schemas_for_operator(self.type)
# match number of inputs
overloaded_defs = [len(s.arguments) for s in schemas]
matched = overloaded_defs.index(len(inputs))
args_list = []
for idx, arg in enumerate(schemas[matched].arguments):
if arg.name == 'dtype':
arg_str = f'dtype={scalar_type_to_pytorch_type[inputs_value[idx]]}' if inputs_value[idx] is not None else ''
elif arg.name == 'layout':
if inputs_value[idx] is not None:
arg_str = f'layout=torch.strided'
print('Warning: only support `torch.strided` for now!!!')
else:
arg_str = ''
elif arg.name == 'device':
arg_str = f'device=torch.device({inputs[idx]})' if inputs_value[idx] is not None else ''
elif arg.name == 'memory_format':
arg_str = f'memory_format={mem_format[inputs_value[idx]]}' if inputs_value[idx] is not None else ''
elif arg.name == 'pin_memory':
# TODO: deal with this argument
continue
elif arg.name == 'requires_grad':
arg_str = f'requires_grad={inputs[idx]}' if inputs_value[idx] else ''
elif str(arg.type).startswith('Optional['):
arg_str = f'{arg.name}={inputs[idx]}'
else:
arg_str = f'{inputs[idx]}'
if arg_str != '':
args_list.append(arg_str)
op_name = self.type.split('::')[-1]
if hasattr(torch, op_name):
return f'{output} = torch.{op_name}({", ".join(args_list)})'
else:
return f'{output} = {inputs[0]}.{op_name}({", ".join(args_list[1:])})'
# ====================================
class AtenFloordiv(PyTorchOperation):
_ori_type_name = ['aten::floordiv']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]} // {inputs[1]}'
class AtenMul(PyTorchOperation):
_ori_type_name = ['aten::mul']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]} * {inputs[1]}'
class AtenLen(PyTorchOperation):
_ori_type_name = ['aten::len']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = len({inputs[0]})'
class AtenIntImplicit(PyTorchOperation):
_ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.type.endswith('Implicit'):
return f'{output} = {inputs[0]}'
elif self.type == 'aten::Int':
return f'{output} = int({inputs[0]})'
elif self.type == 'aten::Float':
return f'{output} = float({inputs[0]})'
raise TypeError(f'Unexpected type: {self.type}')
class AtenIndex(PyTorchOperation):
_ori_type_name = ['aten::index']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}[{inputs[1]}]'
ManuallyChooseDef = {
'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')],
'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')],
# in v1.9 dtype is supported as input argument for view, but torch script does not support it
'aten::view': [('size', 'List[int]', 'None')],
# NOTE: dim supports different types: List[int], List[str], Optional[List[int]], now we only support the first two, refactor needed
# torch.std(input, dim, unbiased, keepdim=False, *, out=None) Tensor
# torch.std(input, unbiased) Tensor
'aten::std': [('dim', 'List[int]', 'None'), ('unbiased', 'bool', 'True'), ('keepdim', 'bool', 'False')]
}
TensorOpExceptions = {
'aten::sub': lambda output, inputs: f'{output} = {inputs[0]} - {inputs[1]}', # example: x.size(1) - 3
'aten::add': lambda output, inputs: f'{output} = {inputs[0]} + {inputs[1]}' # example: input.shape[0] + 5
}
TorchOpExclude = ['aten::Size', 'aten::as_tensor', 'aten::device',
'aten::manual_seed', 'aten::quantized_gru', 'aten::quantized_lstm',
'aten::save', 'aten::tensor', 'aten::wait'
]
def _hidden(name):
return name.startswith('_') and not name.startswith('__')
def _emit_args(args):
# filter out the `out` argument here
return [(arg.name, str(arg.type), str(arg.default_value)) for arg in args] # if arg.name != 'out'
def _get_tensor_ops():
def is_tensor_method(schema):
if len(schema.arguments) == 0:
return False
self = schema.arguments[0]
if self.name != 'self':
return False
if not self.type.isSubtypeOf(torch._C.TensorType.get()):
return False
return True
op_args = {}
# discover methods
for elem in dir(torch.Tensor):
if not _hidden(elem):
schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem)
for schema in schemas:
if is_tensor_method(schema):
op_name = 'aten::' + elem
args = _emit_args(schema.arguments[1:])
if op_name in op_args:
op_args[op_name].append(args)
else:
op_args[op_name] = [args]
return op_args.keys(), op_args
def _get_torch_ops():
torch_op_args = {}
for mod in torch.jit._builtins._modules_containing_builtins: # type: ignore
name = mod.__name__
if name == 'torch._C._nn':
continue
# only process 'torch.XXX'
for elem in dir(mod):
builtin = torch.jit._builtins._find_builtin(getattr(mod, elem)) # type: ignore
if builtin is not None:
schemas = torch._C._jit_get_schemas_for_operator(builtin)
for schema in schemas:
# remove _tan but not __and__
if not _hidden(elem):
op_name = 'aten::' + elem
if len(schema.arguments) > 0 and schema.arguments[0].name == 'self':
continue
args = _emit_args(schema.arguments)
if op_name in torch_op_args:
torch_op_args[op_name].append(args)
else:
torch_op_args[op_name] = [args]
return torch_op_args.keys(), torch_op_args
def _get_torch_ops_exclude_tensor_ops():
tensor_op_names, _ = _get_tensor_ops()
torch_op_names, torch_ops = _get_torch_ops()
torch_exclude_ops = {}
for name in torch_op_names:
if name not in tensor_op_names:
if name not in TorchOpExclude:
# exclude the ops that are not in
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
torch_exclude_ops[name] = torch_ops[name]
return torch_exclude_ops.keys(), torch_exclude_ops
class TensorOps(PyTorchOperation):
"""
corresponding to _get_tensor_ops in torch.jit.supported_ops
"""
_ori_type_name, _op_args = _get_tensor_ops()
comparison_ops = {'aten::eq': '==', 'aten::ne': '!=', 'aten::le': '<=', 'aten::ge': '>=', 'aten::lt': '<', 'aten::gt': '>'}
@staticmethod
def _get_matched_args(_type, inputs):
def has_same_arg_name(matched):
concated_names = []
for i, each in enumerate(matched):
name = ','.join([arg[0] for arg in each])
concated_names.append(name)
for i in range(len(concated_names) - 1):
if concated_names[i] != concated_names[i + 1]:
return False
return True
overloaded_defs = TensorOps._op_args[_type]
matched = []
for each in overloaded_defs:
# plus 1 because we skip the first argument when generating tensor op def
if len(each) + 1 == len(inputs):
matched.append(each)
if len(matched) == 1:
return matched[0]
elif len(matched) > 1:
# TODO: match with arg's type. manually choose for now
if has_same_arg_name(matched):
# returning any one is okay
return matched[0]
elif _type in ManuallyChooseDef:
return ManuallyChooseDef[_type]
else:
raise RuntimeError(f'tensor op type {_type} has more than one matched: {matched}')
else:
if _type in TensorOpExceptions:
return None
raise RuntimeError(f'tensor op type {_type} has no matched')
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# TODO: deal with conditional ops
if self.type in TensorOps.comparison_ops:
return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
matched_args = TensorOps._get_matched_args(self.type, inputs)
if matched_args is None:
return TensorOpExceptions[self.type](output, inputs)
op_name = self.type.split('::')[-1]
args_str = ', '.join([f'{name}={inputs[i+1]}' for i, (name, t, default) in enumerate(matched_args)])
return f'{output} = {inputs[0]}.{op_name}({args_str})'
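# Illustrative TensorOps codegen: for a comparison op such as `aten::eq` with inputs
# ['x', 'y'] the generated line is `out = (x == y)`; for `aten::view` (resolved through
# ManuallyChooseDef when several schemas match) with inputs ['x', 'shape'] it is
# `out = x.view(size=shape)`.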
class TorchOps(PyTorchOperation):
"""
corresponding to _get_nn_functional_ops in torch.jit.supported_ops
"""
_ori_type_name, _op_args = _get_torch_ops_exclude_tensor_ops()
# add 'aten::pixel_shuffle'
_op_args['aten::pixel_shuffle'] = [[('input', 'Tensor', 'None'), ('upscale_factor', 'Optional[int]', 'None')]]
_ori_type_name = _op_args.keys()
@staticmethod
def _get_matched_args(_type, inputs):
def has_same_arg_name(matched):
concated_names = []
for i, each in enumerate(matched):
name = ','.join([arg[0] for arg in each])
concated_names.append(name)
for i in range(len(concated_names) - 1):
if concated_names[i] != concated_names[i + 1]:
return False
return True
overloaded_defs = TorchOps._op_args[_type]
matched = []
for each in overloaded_defs:
if len(each) == len(inputs):
matched.append(each)
if len(matched) == 1:
return matched[0]
elif len(matched) > 1:
# TODO: match with arg's type. manually choose for now
if has_same_arg_name(matched):
# returning any one is okay
return matched[0]
else:
raise RuntimeError(f'torch op type {_type} has more than one matched: {matched}')
else:
raise RuntimeError(f'torch op type {_type} has no matched')
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
matched_args = TorchOps._get_matched_args(self.type, inputs)
op_name = self.type.split('::')[-1]
args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
for i, (name, t, default) in enumerate(matched_args)])
return f'{output} = torch.{op_name}({args_str})'
class AtenAvgpool2d(PyTorchOperation):
# NOTE: it is not included in the above aten ops for unknown reasons
_ori_type_name = ['aten::avg_pool2d']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = F.avg_pool2d({", ".join(inputs)})'
class ToDevice(PyTorchOperation):
_artificial_op_name = "ToDevice"
def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False,
attributes: Dict[str, Any] = {}):
self.type = "ToDevice"
self.device = parameters['device']
self.overridden_device_repr = None
self.src = parameters['src']
self.dst = parameters['dst']
def override_device_repr(self, device_repr):
# CUDA GPUDevice may remap physical GPU IDs to CUDA IDs, so the device repr can differ from GPUDevice.device_repr().
# override_device_repr is called in pytorch.graph_to_pytorch_model to replace device_repr with the correct CUDA ID.
# For example, when a job uses physical GPUs 1 and 2, their CUDA IDs should be "cuda:0" and "cuda:1";
# self.device.device_repr() would return "cuda:1" and "cuda:2", so the overridden reprs should be "cuda:0" and
# "cuda:1".
self.overridden_device_repr = device_repr
def __repr__(self):
if self.overridden_device_repr is None:
return f'to("{self.device.device_repr()}")'
else:
return f'to("{self.overridden_device_repr}")'
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.overridden_device_repr is None:
forward_code = f'{output} = {inputs[0]}.to("{self.device.device_repr()}")'
else:
forward_code = f'{output} = {inputs[0]}.to("{self.overridden_device_repr}")'
return forward_code
class AtenDet(PyTorchOperation):
# for torch 1.9
# NOTE: it is not included in the above aten ops, maybe because torch.det is an alias for torch.linalg.det
_ori_type_name = ['aten::linalg_det']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = torch.det({inputs[0]})'
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict, Any, Type, cast
import torch.nn as nn
from nni.nas.execution.common import (
Model, receive_trial_parameters,
get_mutation_dict, mutation_dict_to_summary
)
from nni.nas.evaluator import Evaluator
from nni.nas.utils import ContextStack
from .graph import BaseExecutionEngine
class PythonGraphData:
def __init__(self, class_: Type[nn.Module], init_parameters: Dict[str, Any],
mutation: Dict[str, Any], evaluator: Evaluator) -> None:
self.class_ = class_
self.init_parameters = init_parameters
self.mutation = mutation
self.evaluator = evaluator
self.mutation_summary = mutation_dict_to_summary(mutation)
def dump(self) -> dict:
return {
'class': self.class_,
'init_parameters': self.init_parameters,
'mutation': self.mutation,
# engine needs to call dump here,
# otherwise, evaluator will become binary
# also, evaluator can be none in tests
'evaluator': self.evaluator._dump() if self.evaluator is not None else None,
'mutation_summary': self.mutation_summary
}
@staticmethod
def load(data) -> 'PythonGraphData':
return PythonGraphData(data['class'], data['init_parameters'], data['mutation'], Evaluator._load(data['evaluator']))
class PurePythonExecutionEngine(BaseExecutionEngine):
"""
This is the execution engine that doesn't rely on the Python-IR converter.
We haven't explicitly stated this independence yet. The front-end needs to decide which converter (or no converter)
to use depending on the execution type. In the future, that logic may be moved into this execution engine.
The execution engine needs to store the class path and init parameters of the base model, so that the model can be
re-initialized with the mutation dict in the context, making the mutable modules become fixed instances on the fly.
"""
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation = get_mutation_dict(model)
assert model.evaluator is not None, 'Model evaluator is not available.'
graph_data = PythonGraphData(
cast(Type[nn.Module], model.python_class),
model.python_init_params or {}, mutation, model.evaluator
)
return graph_data
@classmethod
def trial_execute_graph(cls) -> None:
graph_data = PythonGraphData.load(receive_trial_parameters())
def _model():
return graph_data.class_(**graph_data.init_parameters)
with ContextStack('fixed', graph_data.mutation):
graph_data.evaluator._execute(_model)
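# Sketch of the pure-Python trial flow: the trial re-instantiates the user's model class with its
# original init parameters inside `ContextStack('fixed', mutation)`, so mutable modules
# (e.g. LayerChoice) resolve to the sampled candidates on the fly, as described in the class
# docstring above.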
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.nas.execution.common import TensorFlowOperation
class Conv2D(TensorFlowOperation):
def __init__(self, type_name, parameters, _internal, attributes=None):
if 'padding' not in parameters:
parameters['padding'] = 'same'
super().__init__(type_name, parameters, _internal)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Entrypoint for trials.
"""
import argparse
def main():
parser = argparse.ArgumentParser()
parser.add_argument('exec', choices=['base', 'py', 'cgo', 'benchmark'])
args = parser.parse_args()
if args.exec == 'base':
from .pytorch.graph import BaseExecutionEngine
engine = BaseExecutionEngine
elif args.exec == 'cgo':
from .pytorch.cgo import CGOExecutionEngine
engine = CGOExecutionEngine
elif args.exec == 'py':
from .pytorch.simplified import PurePythonExecutionEngine
engine = PurePythonExecutionEngine
elif args.exec == 'benchmark':
from .pytorch.benchmark import BenchmarkExecutionEngine
engine = BenchmarkExecutionEngine
else:
raise ValueError(f'Unrecognized execution engine: {args.exec}')
engine.trial_execute_graph()
if __name__ == '__main__':
main()
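# Typical invocation (the full command is assembled by RetiariiExeConfig._canonicalize further
# below), shown here for the pure-Python engine:
#   python -m nni.retiarii.trial_entry py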
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import DartsMutator
from .trainer import DartsTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import get_and_apply_next_architecture
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .experiment_config import *
from .engine_config import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Optional, List
from nni.experiment.config.base import ConfigBase
__all__ = ['ExecutionEngineConfig', 'BaseEngineConfig', 'OneshotEngineConfig',
'PyEngineConfig', 'CgoEngineConfig', 'BenchmarkEngineConfig']
@dataclass(init=False)
class ExecutionEngineConfig(ConfigBase):
name: str
@dataclass(init=False)
class PyEngineConfig(ExecutionEngineConfig):
name: str = 'py'
@dataclass(init=False)
class OneshotEngineConfig(ExecutionEngineConfig):
name: str = 'oneshot'
@dataclass(init=False)
class BaseEngineConfig(ExecutionEngineConfig):
name: str = 'base'
# input used in GraphConverterWithShape. Currently supports a shape tuple only.
dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class CgoEngineConfig(ExecutionEngineConfig):
name: str = 'cgo'
max_concurrency_cgo: Optional[int] = None
batch_waiting_time: Optional[int] = None
# input used in GraphConverterWithShape. Currently supports a shape tuple only.
dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class BenchmarkEngineConfig(ExecutionEngineConfig):
name: str = 'benchmark'
benchmark: Optional[str] = None
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, Union, Optional
from nni.experiment.config import utils, ExperimentConfig
from .engine_config import ExecutionEngineConfig
__all__ = ['RetiariiExeConfig']
def execution_engine_config_factory(engine_name):
# FIXME: may move this function to experiment utils in future
cls = _get_ee_config_class(engine_name)
if cls is None:
raise ValueError(f'Invalid execution engine name: {engine_name}')
return cls()
def _get_ee_config_class(engine_name):
for cls in ExecutionEngineConfig.__subclasses__():
if cls.name == engine_name:
return cls
return None
@dataclass(init=False)
class RetiariiExeConfig(ExperimentConfig):
# FIXME: refactor this class to inherit from a new common base class with HPO config
search_space: Any = ''
trial_code_directory: utils.PathLike = '.'
trial_command: str = '_reserved'
# new config field for NAS
execution_engine: Union[str, ExecutionEngineConfig]
# Internal: to support customized fields in trial command
# Useful when customized python / environment variables are needed
_trial_command_params: Optional[Dict[str, Any]] = None
def __init__(self, training_service_platform: Union[str, None] = None,
execution_engine: Union[str, ExecutionEngineConfig] = 'py',
**kwargs):
super().__init__(training_service_platform, **kwargs)
self.execution_engine = execution_engine
def _canonicalize(self, _parents):
msg = '{} is not supposed to be set in Retiarii experiment by users, your config is {}.'
if self.search_space != '':
raise ValueError(msg.format('search_space', self.search_space))
# TODO: maybe we should also allow users to specify trial_code_directory
if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))
trial_command_tmpl = '{envs} {python} -m nni.retiarii.trial_entry {execution_engine}'
if self.trial_command != '_reserved' and '-m nni.retiarii.trial_entry' not in self.trial_command:
raise ValueError(msg.format('trial_command', self.trial_command))
if isinstance(self.execution_engine, str):
self.execution_engine = execution_engine_config_factory(self.execution_engine)
_trial_command_params = {
# Default variables
'envs': '',
# TODO: maybe use sys.executable rendered in trial side (e.g., trial_runner)
'python': sys.executable,
'execution_engine': self.execution_engine.name,
# This should override the parameters above.
**(self._trial_command_params or {})
}
self.trial_command = trial_command_tmpl.format(**_trial_command_params).strip()
super()._canonicalize([self])
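# Illustrative use of the config above: the engine may be given by name and is canonicalized
# into its config class, e.g.
#   exp_config = RetiariiExeConfig('local', execution_engine='base')
#   # after canonicalization, exp_config.execution_engine is a BaseEngineConfig and
#   # exp_config.trial_command is roughly '<python> -m nni.retiarii.trial_entry base'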
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['RetiariiExeConfig', 'RetiariiExperiment', 'preprocess_model', 'debug_mutated_model']
import logging
import warnings
from threading import Thread
from typing import Any, List, cast
import colorama
import torch
import torch.nn as nn
from nni.experiment import Experiment, RunMode
from nni.experiment.config.training_services import RemoteConfig
from nni.nas.execution import list_models, set_execution_engine
from nni.nas.execution.common import RetiariiAdvisor, get_mutation_dict
from nni.nas.execution.pytorch.codegen import model_to_pytorch_script
from nni.nas.execution.pytorch.converter import convert_to_graph
from nni.nas.execution.pytorch.converter.graph_gen import GraphConverterWithShape
from nni.nas.evaluator import Evaluator
from nni.nas.mutable import Mutator
from nni.nas.nn.pytorch.mutator import (
extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations, process_oneshot_mutations
)
from nni.nas.utils import is_model_wrapped
from nni.nas.strategy import BaseStrategy
from nni.nas.strategy.utils import dry_run_for_formatted_search_space
from .config import (
RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig,
PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
)
_logger = logging.getLogger(__name__)
def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False):
# TODO: this logic might need to be refactored into execution engine
if oneshot:
base_model_ir, mutators = process_oneshot_mutations(base_model, evaluator)
elif full_ir:
try:
script_module = torch.jit.script(base_model)
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
if dummy_input is not None:
# FIXME: this is a workaround as full tensor is not supported in configs
dummy_input = torch.randn(*dummy_input)
converter = GraphConverterWithShape()
base_model_ir = convert_to_graph(script_module, base_model, converter, dummy_input=dummy_input)
else:
base_model_ir = convert_to_graph(script_module, base_model)
# handle inline mutations
mutators = process_inline_mutation(base_model_ir)
else:
base_model_ir, mutators = extract_mutation_from_pt_module(base_model)
base_model_ir.evaluator = evaluator
if mutators is not None and applied_mutators:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
'do not use mutators when you use LayerChoice/InputChoice')
if mutators is not None:
applied_mutators = mutators
# Add mutations on evaluators
applied_mutators += process_evaluator_mutations(evaluator, applied_mutators)
return base_model_ir, applied_mutators
def debug_mutated_model(base_model, evaluator, applied_mutators):
"""
Locally run only one trial without launching an experiment, for debugging purposes, then exit.
For example, it can be used to quickly check shape mismatch.
Specifically, it applies mutators (default to choose the first candidate for the choices)
to generate a new model, then run this model locally.
The model will be parsed with graph execution engine.
Parameters
----------
base_model : nni.retiarii.nn.pytorch.nn.Module
the base model
evaluator : nni.retiarii.graph.Evaluator
the training class of the generated models
applied_mutators : list
a list of mutators that will be applied on the base model for generating a new model
"""
base_model_ir, applied_mutators = preprocess_model(base_model, evaluator, applied_mutators)
from nni.nas.strategy.debug import _LocalDebugStrategy
strategy = _LocalDebugStrategy()
strategy.run(base_model_ir, applied_mutators)
_logger.info('local debug completed!')
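# Illustrative call (`Net` / `evaluate_model` are placeholders, mirroring the RetiariiExperiment
# docstring examples below): run a single mutated model locally without launching an experiment.
#   debug_mutated_model(Net(), FunctionalEvaluator(evaluate_model), [])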
class RetiariiExperiment(Experiment):
"""
The entry for a NAS experiment.
Users can use this class to start/stop or inspect an experiment, like exporting the results.
RetiariiExperiment is a sub-class of :class:`nni.experiment.Experiment`; there are many similarities, such as a
configurable training service for running the experiment distributedly on remote servers.
But unlike :class:`nni.experiment.Experiment`, RetiariiExperiment doesn't support configuring:
- ``trial_code_directory``, which can only be the current working directory.
- ``search_space``, which is auto-generated in NAS.
- ``trial_command``, which must be ``python -m nni.retiarii.trial_entry`` to launch the modularized trial code.
RetiariiExperiment also doesn't have a tuner/assessor/advisor, because they are implemented in the strategy.
Also, unlike :class:`nni.experiment.Experiment`, which is bound to a node server,
RetiariiExperiment only optionally starts a node server to schedule the trials, when the strategy is a multi-trial strategy.
When the strategy is one-shot, the step of launching the node server is omitted, and the experiment is run locally by default.
Configurations of experiments, such as execution engine, number of GPUs allocated,
should be put into a :class:`RetiariiExeConfig` and used as an argument of :meth:`RetiariiExperiment.run`.
Parameters
----------
base_model : nn.Module
The model defining the search space / base skeleton without mutation.
It should be wrapped by decorator ``nni.retiarii.model_wrapper``.
evaluator : nni.retiarii.Evaluator, default = None
Evaluator for the experiment.
If you are using a one-shot trainer, it should be placed here, although this usage is deprecated.
applied_mutators : list of nni.retiarii.Mutator, default = None
Mutators to mutate the base model. If None, mutators are skipped.
Note that when ``base_model`` uses inline mutations (e.g., LayerChoice), ``applied_mutators`` must be empty / none.
strategy : nni.retiarii.strategy.BaseStrategy, default = None
Exploration strategy. Can be multi-trial or one-shot.
trainer : BaseOneShotTrainer
Kept for compatibility purposes.
Examples
--------
Multi-trial NAS:
>>> base_model = Net()
>>> search_strategy = strategy.Random()
>>> model_evaluator = FunctionalEvaluator(evaluate_model)
>>> exp = RetiariiExperiment(base_model, model_evaluator, [], search_strategy)
>>> exp_config = RetiariiExeConfig('local')
>>> exp_config.trial_concurrency = 2
>>> exp_config.max_trial_number = 20
>>> exp_config.training_service.use_active_gpu = False
>>> exp.run(exp_config, 8081)
One-shot NAS:
>>> base_model = Net()
>>> search_strategy = strategy.DARTS()
>>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)
>>> exp = RetiariiExperiment(base_model, evaluator, [], search_strategy)
>>> exp_config = RetiariiExeConfig()
>>> exp_config.execution_engine = 'oneshot' # must be set for one-shot strategy
>>> exp.run(exp_config)
Export top models:
>>> for model_dict in exp.export_top_models(formatter='dict'):
... print(model_dict)
>>> with nni.retiarii.fixed_arch(model_dict):
... final_model = Net()
"""
def __init__(self, base_model: nn.Module,
evaluator: Evaluator = cast(Evaluator, None),
applied_mutators: List[Mutator] = cast(List[Mutator], None),
strategy: BaseStrategy = cast(BaseStrategy, None),
trainer: Any = None):
super().__init__(None)
self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
if trainer is not None:
warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
evaluator = trainer
if evaluator is None:
raise ValueError('Evaluator should not be none.')
self.base_model = base_model
self.evaluator: Evaluator = evaluator
self.applied_mutators = applied_mutators
self.strategy = strategy
self._dispatcher = None
self._dispatcher_thread = None
# check for sanity
if not is_model_wrapped(base_model):
warnings.warn(colorama.Style.BRIGHT + colorama.Fore.RED +
'`@model_wrapper` is missing for the base model. The experiment might still be able to run, '
'but it may cause inconsistent behavior compared to when `@model_wrapper` is added.' + colorama.Style.RESET_ALL,
RuntimeWarning)
def _run_strategy(self, config: RetiariiExeConfig):
base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.evaluator, self.applied_mutators,
full_ir=not isinstance(config.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)),
dummy_input=config.execution_engine.dummy_input
if isinstance(config.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None
)
_logger.info('Start strategy...')
search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators)
self.update_search_space(search_space)
self.strategy.run(base_model_ir, self.applied_mutators)
_logger.info('Strategy exit')
# TODO: find out a proper way to show no more trial message on WebUI
def _create_execution_engine(self, config: RetiariiExeConfig) -> None:
# TODO: we will probably need an execution engine factory to make this clean and elegant
if isinstance(config.execution_engine, BaseEngineConfig):
from nni.nas.execution.pytorch.graph import BaseExecutionEngine
engine = BaseExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, CgoEngineConfig):
from nni.nas.execution.pytorch.cgo import CGOExecutionEngine
assert not isinstance(config.training_service, list) \
and config.training_service.platform == 'remote', \
"CGO execution engine currently only supports remote training service"
assert config.execution_engine.batch_waiting_time is not None \
and config.execution_engine.max_concurrency_cgo is not None
engine = CGOExecutionEngine(cast(RemoteConfig, config.training_service),
max_concurrency=config.execution_engine.max_concurrency_cgo,
batch_waiting_time=config.execution_engine.batch_waiting_time,
rest_port=self.port,
rest_url_prefix=self.url_prefix)
elif isinstance(config.execution_engine, PyEngineConfig):
from nni.nas.execution.pytorch.simplified import PurePythonExecutionEngine
engine = PurePythonExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, BenchmarkEngineConfig):
from nni.nas.execution.pytorch.benchmark import BenchmarkExecutionEngine
assert config.execution_engine.benchmark is not None, \
'"benchmark" must be set when benchmark execution engine is used.'
engine = BenchmarkExecutionEngine(config.execution_engine.benchmark)
else:
raise ValueError(f'Unsupported engine type: {config.execution_engine}')
set_execution_engine(engine)
def start(self, *args, **kwargs) -> None:
"""
By design, the only difference between `start` and `run` is that `start` is asynchronous,
while `run` waits for the experiment to complete. RetiariiExperiment always waits for the experiment
to complete, as the strategy runs in the foreground.
"""
raise NotImplementedError('RetiariiExperiment is not supposed to provide `start` method')
def run(self,
config: RetiariiExeConfig | None = None,
port: int = 8080,
debug: bool = False) -> None:
"""
Run the experiment.
This function will block until experiment finish or error.
"""
from nni.retiarii.oneshot.interface import BaseOneShotTrainer
if isinstance(self.evaluator, BaseOneShotTrainer):
warnings.warn('You are using the old implementation of one-shot algos based on One-shot trainer. '
'We will try to convert this trainer to our new implementation to run the algorithm. '
'In case you want to stick to the old implementation, '
'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
self.evaluator.fit()
return
if config is None:
warnings.warn('config = None is deprecated and will be removed in the future. If you are running a one-shot experiment, '
'please consider creating a config and set execution engine to `oneshot`.', DeprecationWarning)
self.config = RetiariiExeConfig()
self.config.execution_engine = OneshotEngineConfig()
else:
self.config = config
if isinstance(self.config.execution_engine, OneshotEngineConfig) \
or (isinstance(self.config.execution_engine, str) and self.config.execution_engine == 'oneshot'):
# this is hacky, will be refactored when oneshot can run on training services
base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.evaluator, self.applied_mutators, oneshot=True)
self.strategy.run(base_model_ir, self.applied_mutators)
else:
ws_url = f'ws://localhost:{port}/tuner'
canonicalized_config = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii'])
canonicalized_config = cast(RetiariiExeConfig, canonicalized_config)
self._dispatcher = RetiariiAdvisor(ws_url)
self._dispatcher_thread = Thread(target=self._dispatcher.run, daemon=True)
self._dispatcher_thread.start()
# FIXME: engine cannot be created twice
self._create_execution_engine(canonicalized_config)
try:
self._run_strategy(canonicalized_config)
# FIXME: move this logic to strategy with a new API provided by execution engine
self._wait_completion()
except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected')
self.stop()
_logger.info('Search process is done, the experiment is still alive, `stop()` can terminate the experiment.')
def stop(self) -> None:
"""
Stop background experiment.
"""
_logger.info('Stopping experiment, please wait...')
self._stop_impl()
if self._dispatcher_thread:
self._dispatcher_thread.join()
self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread = None
_logger.info('Experiment stopped')
def export_top_models(self, top_k: int = 1, optimize_mode: str = 'maximize', formatter: str = 'dict') -> Any:
"""
Export several top performing models.
For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
available for customization.
The concrete behavior of export depends on each strategy.
See the documentation of each strategy for detailed specifications.
Parameters
----------
top_k : int
How many models are intended to be exported.
optimize_mode : str
``maximize`` or ``minimize``. Not supported by one-shot algorithms.
``optimize_mode`` is likely to be removed and defined in strategy in future.
formatter : str
Support ``code`` and ``dict``. Not supported by one-shot algorithms.
If ``code``, the python code of model will be returned.
If ``dict``, the mutation history will be returned.
"""
# TODO: the base class may also need this method
if formatter == 'code':
config = self.config.canonical_copy()
assert not isinstance(config.execution_engine, PyEngineConfig), \
'You should use `dict` formatter when using Python execution engine.'
from nni.retiarii.oneshot.interface import BaseOneShotTrainer
if isinstance(self.evaluator, BaseOneShotTrainer):
assert top_k == 1, 'Only support top_k is 1 for now.'
return self.evaluator.export()
try:
# this currently works for one-shot algorithms
return self.strategy.export_top_models(top_k=top_k)
except NotImplementedError:
# when strategy hasn't implemented its own export logic
all_models = filter(lambda m: m.metric is not None, list_models())
assert optimize_mode in ['maximize', 'minimize']
all_models = sorted(all_models, key=lambda m: cast(float, m.metric), reverse=optimize_mode == 'maximize')
assert formatter in ['code', 'dict'], 'Export formatter other than "code" and "dict" is not supported yet.'
if formatter == 'code':
return [model_to_pytorch_script(model) for model in all_models[:top_k]]
elif formatter == 'dict':
return [get_mutation_dict(model) for model in all_models[:top_k]]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from pathlib import Path
from typing import Union, Dict, Any
from .utils import ContextStack
_logger = logging.getLogger(__name__)
def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
"""
Load architecture from ``fixed_arch`` and apply it to the model. This should be used as a context manager. For example,
.. code-block:: python
with fixed_arch('/path/to/export.json'):
model = Model(3, 224, 224)
Parameters
----------
fixed_arch : str, Path or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
-------
ContextStack
Context manager that provides a fixed architecture when creating the model.
"""
if isinstance(fixed_arch, (str, Path)):
with open(fixed_arch) as f:
fixed_arch = json.load(f)
if verbose:
_logger.info('Fixed architecture: %s', fixed_arch)
return ContextStack('fixed', fixed_arch)
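# A minimal sketch (assumption, not in the original file): `fixed_arch` also accepts an
# already-loaded dict, e.g. the result of `export_top_models(..., formatter='dict')`.
# `Model` and the keys below are hypothetical stand-ins for a user-defined model space.
#
#     arch = {'layerchoice_1': 'conv3x3', 'mlp_ratio_0': 4.0}
#     with fixed_arch(arch, verbose=False):
#         model = Model(3, 224, 224)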
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mobilenetv3 import MobileNetV3Space
from .nasbench101 import NasBench101
from .nasbench201 import NasBench201
from .nasnet import NDS, NASNet, ENAS, AmoebaNet, PNAS, DARTS
from .proxylessnas import ProxylessNAS
from .shufflenet import ShuffleNetSpace
from .autoformer import AutoformerSpace
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Tuple, cast, Any, Dict
import torch
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper, basic_unit
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.nas.oneshot.pytorch.supermodule.operation import MixedOperation
from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options
from nni.nas.oneshot.pytorch.supermodule._operation_utils import Slicable as _S, MaybeWeighted as _W
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
class RelativePosition2D(nn.Module):
def __init__(self, head_embed_dim, length=14,) -> None:
super().__init__()
self.head_embed_dim = head_embed_dim
self.length = length
self.embeddings_table_v = nn.Parameter(torch.randn(length * 2 + 2, head_embed_dim))
self.embeddings_table_h = nn.Parameter(torch.randn(length * 2 + 2, head_embed_dim))
trunc_normal_(self.embeddings_table_v, std=.02)
trunc_normal_(self.embeddings_table_h, std=.02)
def forward(self, length_q, length_k):
# remove the first cls token distance computation
length_q = length_q - 1
length_k = length_k - 1
# init in the device directly, rather than move to device
range_vec_q = torch.arange(length_q, device=self.embeddings_table_v.device)
range_vec_k = torch.arange(length_k, device=self.embeddings_table_v.device)
# compute the row and column distance
length_q_sqrt = int(length_q ** 0.5)
distance_mat_v = (range_vec_k[None, :] // length_q_sqrt - range_vec_q[:, None] // length_q_sqrt)
distance_mat_h = (range_vec_k[None, :] % length_q_sqrt - range_vec_q[:, None] % length_q_sqrt)
# clip the distance to the range of [-length, length]
distance_mat_clipped_v = torch.clamp(distance_mat_v, -self.length, self.length)
distance_mat_clipped_h = torch.clamp(distance_mat_h, -self.length, self.length)
# translate the distance to the range [1, 2 * length + 1]; 0 is reserved for the cls token
final_mat_v = distance_mat_clipped_v + self.length + 1
final_mat_h = distance_mat_clipped_h + self.length + 1
# pad with 0, which represents the cls token
final_mat_v = F.pad(final_mat_v, (1, 0, 1, 0), "constant", 0)
final_mat_h = F.pad(final_mat_h, (1, 0, 1, 0), "constant", 0)
final_mat_v = final_mat_v.long()
final_mat_h = final_mat_h.long()
# get the embeddings with the corresponding distance
embeddings = self.embeddings_table_v[final_mat_v] + self.embeddings_table_h[final_mat_h]
return embeddings
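# Shape sketch (illustrative, assuming a 14x14 patch grid plus one cls token, i.e. 197 tokens):
#
#     rel_pos = RelativePosition2D(head_embed_dim=64, length=14)
#     emb = rel_pos(197, 197)   # -> torch.Size([197, 197, 64]), one vector per (query, key) pair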
class RelativePositionAttention(nn.Module):
"""
This class is designed to support the relative position in attention.
The pytorch built-in nn.MultiheadAttention() does not support relative position embedding.
Different from the absolute position embedding, the relative position embedding encodes
the relative distance between input tokens and learns their pairwise relations.
It is commonly calculated via a look-up table with learnable parameters interacting with queries
and keys in self-attention modules.
"""
def __init__(
self, embed_dim, num_heads,
attn_drop=0., proj_drop=0.,
qkv_bias=False, qk_scale=None,
rpe_length=14, rpe=False,
head_dim=64):
super().__init__()
self.num_heads = num_heads
# head_dim is fixed to 64 in the official Autoformer. Set head_dim = None to use a flexible head dim.
self.head_dim = head_dim or (embed_dim // num_heads)
self.scale = qk_scale or self.head_dim ** -0.5
# Please refer to MixedMultiheadAttention for details.
self.q = nn.Linear(embed_dim, self.head_dim * num_heads, bias=qkv_bias)
self.k = nn.Linear(embed_dim, self.head_dim * num_heads, bias=qkv_bias)
self.v = nn.Linear(embed_dim, self.head_dim * num_heads, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(self.head_dim * num_heads, embed_dim)
self.proj_drop = nn.Dropout(proj_drop)
self.rpe = rpe
if rpe:
self.rel_pos_embed_k = RelativePosition2D(self.head_dim, rpe_length)
self.rel_pos_embed_v = RelativePosition2D(self.head_dim, rpe_length)
def forward(self, x):
B, N, _ = x.shape
head_dim = self.head_dim
# num_heads cannot be read from self.num_heads directly here,
# so use -1 and let reshape infer it.
num_heads = -1
q = self.q(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
k = self.k(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
v = self.v(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
num_heads = q.size(1)
attn = (q @ k.transpose(-2, -1)) * self.scale
if self.rpe:
r_p_k = self.rel_pos_embed_k(N, N)
attn = attn + (
q.permute(2, 0, 1, 3).reshape(N, num_heads * B, head_dim) @ r_p_k.transpose(2, 1)
).transpose(1, 0).reshape(B, num_heads, N, N) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, num_heads * head_dim)
if self.rpe:
attn_1 = attn.permute(2, 0, 1, 3).reshape(N, B * num_heads, N)
r_p_v = self.rel_pos_embed_v(N, N)
# The attention map has shape (B, num_heads, N, N); reshape it to (N, B*num_heads, N) and batch-matmul it
# with the relative position embedding of V (N, N, head_dim), which gives shape (N, B*num_heads, head_dim).
# Then reshape back to the same shape as x, i.e. (B, N, num_heads * head_dim).
x = x + (attn_1 @ r_p_v).transpose(1, 0).reshape(B, num_heads, N, head_dim).transpose(2, 1).reshape(B, N, num_heads * head_dim)
x = self.proj(x)
x = self.proj_drop(x)
return x
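# Usage sketch (assumption: plain ints are accepted outside a search context):
#
#     attn = RelativePositionAttention(embed_dim=192, num_heads=3, head_dim=64,
#                                      rpe=True, rpe_length=14)
#     out = attn(torch.randn(2, 197, 192))   # -> shape (2, 197, 192)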
class TransformerEncoderLayer(nn.Module):
"""
This class is designed to support the RelativePositionAttention().
The PyTorch built-in nn.TransformerEncoderLayer() does not support customized attention.
"""
def __init__(
self, embed_dim, num_heads, mlp_ratio=4.,
qkv_bias=False, qk_scale=None, rpe=False,
drop_rate=0., attn_drop=0., proj_drop=0., drop_path=0.,
pre_norm=True, rpe_length=14, head_dim=64
):
super().__init__()
self.normalize_before = pre_norm
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.dropout = drop_rate
self.attn = RelativePositionAttention(
embed_dim=embed_dim,
num_heads=num_heads,
attn_drop=attn_drop,
proj_drop=proj_drop,
rpe=rpe,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
rpe_length=rpe_length,
head_dim=head_dim
)
self.attn_layer_norm = nn.LayerNorm(embed_dim)
self.ffn_layer_norm = nn.LayerNorm(embed_dim)
self.activation_fn = nn.GELU()
self.fc1 = nn.Linear(
cast(int, embed_dim),
cast(int, nn.ValueChoice.to_int(embed_dim * mlp_ratio))
)
self.fc2 = nn.Linear(
cast(int, nn.ValueChoice.to_int(embed_dim * mlp_ratio)),
cast(int, embed_dim)
)
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return layer_norm(x)
else:
return x
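# Note on maybe_layer_norm above: `before ^ after` guarantees exactly one of the two flags is set,
# and `after ^ self.normalize_before` applies the LayerNorm before each sub-layer when pre_norm=True
# (pre-norm) or after it when pre_norm=False (post-norm), so one code path covers both orders.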
def forward(self, x):
"""
Args:
x (Tensor): input to the layer of shape `(batch, patch_num, sample_embed_dim)`
Returns:
encoded output of shape `(batch, patch_num, sample_embed_dim)`
"""
residual = x
x = self.maybe_layer_norm(self.attn_layer_norm, x, before=True)
x = self.attn(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.drop_path(x)
x = residual + x
x = self.maybe_layer_norm(self.attn_layer_norm, x, after=True)
residual = x
x = self.maybe_layer_norm(self.ffn_layer_norm, x, before=True)
x = self.fc1(x)
x = self.activation_fn(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.drop_path(x)
x = residual + x
x = self.maybe_layer_norm(self.ffn_layer_norm, x, after=True)
return x
@basic_unit
class ClsToken(nn.Module):
""" Concat class token with dim=embed_dim before patch embedding.
"""
def __init__(self, embed_dim: int):
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
trunc_normal_(self.cls_token, std=.02)
def forward(self, x):
return torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
class MixedClsToken(MixedOperation, ClsToken):
""" Mixed class token concat operation.
Supported arguments are:
- ``embed_dim``
Prefix of cls_token will be sliced.
"""
bound_type = ClsToken
argument_list = ['embed_dim']
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
return max(traverse_all_options(value_choice))
def forward_with_args(self, embed_dim,
inputs: torch.Tensor) -> torch.Tensor:
embed_dim_ = _W(embed_dim)
cls_token = _S(self.cls_token)[..., :embed_dim_]
return torch.cat((cls_token.expand(inputs.shape[0], -1, -1), inputs), dim=1)
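# Note on the mixed operation above: super_init_argument sizes cls_token for the largest candidate
# embed_dim, and forward_with_args slices a prefix of it (`_S(...)[..., :embed_dim_]`, where `_W`
# marks the sampled, possibly weighted, embed_dim), so one over-sized parameter serves every
# candidate embedding dimension.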
@basic_unit
class AbsPosEmbed(nn.Module):
""" Add absolute position embedding on patch embedding.
"""
def __init__(self, length: int, embed_dim: int):
super().__init__()
self.pos_embed = nn.Parameter(torch.zeros(1, length, embed_dim))
trunc_normal_(self.pos_embed, std=.02)
def forward(self, x):
return x + self.pos_embed
class MixedAbsPosEmbed(MixedOperation, AbsPosEmbed):
""" Mixed absolute position embedding add operation.
Supported arguments are:
- ``embed_dim``
Prefix of pos_embed will be sliced.
"""
bound_type = AbsPosEmbed
argument_list = ['embed_dim']
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
return max(traverse_all_options(value_choice))
def forward_with_args(self, embed_dim,
inputs: torch.Tensor) -> torch.Tensor:
embed_dim_ = _W(embed_dim)
pos_embed = _S(self.pos_embed)[..., :embed_dim_]
return inputs + pos_embed
@model_wrapper
class AutoformerSpace(nn.Module):
"""
The search space that is proposed in `Autoformer <https://arxiv.org/abs/2107.00651>`__.
There are four searchable variables: depth, embedding dimension, number of heads, and MLP ratio.
Parameters
----------
search_embed_dim : list of int
The search space of embedding dimension.
search_mlp_ratio : list of float
The search space of MLP ratio.
search_num_heads : list of int
The search space of number of heads.
search_depth: list of int
The search space of depth.
img_size : int
Size of input image.
patch_size : int
Size of image patch.
in_chans : int
Number of channels of the input image.
num_classes : int
Number of classes for classifier.
qkv_bias : bool
Whether to use bias item in the qkv embedding.
drop_rate : float
Drop rate of the MLP projection in MSA and FFN.
attn_drop_rate : float
Drop rate of attention.
drop_path_rate : float
Drop path rate.
pre_norm : bool
Whether to use pre_norm. Otherwise post_norm is used.
global_pool : bool
Whether to use global pooling to generate the image representation. Otherwise the cls_token is used.
abs_pos : bool
Whether to use absolute positional embeddings.
qk_scale : float
The scaling factor applied to the attention score map in self-attention.
rpe : bool
Whether to use relative position encoding.
"""
def __init__(
self,
search_embed_dim: Tuple[int, ...] = (192, 216, 240),
search_mlp_ratio: Tuple[float, ...] = (3.0, 3.5, 4.0),
search_num_heads: Tuple[int, ...] = (3, 4),
search_depth: Tuple[int, ...] = (12, 13, 14),
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
qkv_bias: bool = False,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
pre_norm: bool = True,
global_pool: bool = False,
abs_pos: bool = True,
qk_scale: Optional[float] = None,
rpe: bool = True,
):
super().__init__()
# define search space parameters
embed_dim = nn.ValueChoice(list(search_embed_dim), label="embed_dim")
depth = nn.ValueChoice(list(search_depth), label="depth")
mlp_ratios = [nn.ValueChoice(list(search_mlp_ratio), label=f"mlp_ratio_{i}") for i in range(max(search_depth))]
num_heads = [nn.ValueChoice(list(search_num_heads), label=f"num_head_{i}") for i in range(max(search_depth))]
self.patch_embed = nn.Conv2d(
in_chans, cast(int, embed_dim),
kernel_size = patch_size,
stride = patch_size
)
self.patches_num = int((img_size // patch_size) ** 2)
self.global_pool = global_pool
self.cls_token = ClsToken(cast(int, embed_dim))
self.pos_embed = AbsPosEmbed(self.patches_num+1, cast(int, embed_dim)) if abs_pos else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, max(search_depth))] # stochastic depth decay rule
self.blocks = nn.Repeat(
lambda index: TransformerEncoderLayer(
embed_dim = embed_dim, num_heads = num_heads[index], mlp_ratio=mlp_ratios[index],
qkv_bias = qkv_bias, drop_rate = drop_rate, attn_drop = attn_drop_rate, drop_path=dpr[index],
rpe_length=img_size // patch_size, qk_scale=qk_scale, rpe=rpe, pre_norm=pre_norm, head_dim = 64
), depth
)
self.norm = nn.LayerNorm(cast(int, embed_dim)) if pre_norm else nn.Identity()
self.head = nn.Linear(cast(int, embed_dim), num_classes) if num_classes > 0 else nn.Identity()
@classmethod
def get_extra_mutation_hooks(cls):
return [MixedAbsPosEmbed.mutate, MixedClsToken.mutate]
@classmethod
def load_searched_model(
cls, name: str,
pretrained: bool = False, download: bool = False, progress: bool = True
) -> nn.Module:
init_kwargs = {'qkv_bias': True, 'drop_rate': 0.0, 'drop_path_rate': 0.1, 'global_pool': True, 'num_classes': 1000}
if name == 'autoformer-tiny':
mlp_ratio = [3.5, 3.5, 3.0, 3.5, 3.0, 3.0, 4.0, 4.0, 3.5, 4.0, 3.5, 4.0, 3.5] + [3.0]
num_head = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3] + [3]
arch: Dict[str, Any] = {
'embed_dim': 192,
'depth': 13
}
for i in range(14):
arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
arch[f'num_head_{i}'] = num_head[i]
init_kwargs.update({
'search_embed_dim': (240, 216, 192),
'search_mlp_ratio': (4.0, 3.5, 3.0),
'search_num_heads': (4, 3),
'search_depth': (14, 13, 12),
})
elif name == 'autoformer-small':
mlp_ratio = [3.0, 3.5, 3.0, 3.5, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.5, 4.0] + [3.0]
num_head = [6, 6, 5, 7, 5, 5, 5, 6, 6, 7, 7, 6, 7] + [5]
arch: Dict[str, Any] = {
'embed_dim': 384,
'depth': 13
}
for i in range(14):
arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
arch[f'num_head_{i}'] = num_head[i]
init_kwargs.update({
'search_embed_dim': (448, 384, 320),
'search_mlp_ratio': (4.0, 3.5, 3.0),
'search_num_heads': (7, 6, 5),
'search_depth': (14, 13, 12),
})
elif name == 'autoformer-base':
mlp_ratio = [3.5, 3.5, 4.0, 3.5, 4.0, 3.5, 3.5, 3.0, 4.0, 4.0, 3.0, 4.0, 3.0, 3.5] + [3.0, 3.0]
num_head = [9, 9, 9, 9, 9, 10, 9, 9, 10, 9, 10, 9, 9, 10] + [8, 8]
arch: Dict[str, Any] = {
'embed_dim': 576,
'depth': 14
}
for i in range(16):
arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
arch[f'num_head_{i}'] = num_head[i]
init_kwargs.update({
'search_embed_dim': (624, 576, 528),
'search_mlp_ratio': (4.0, 3.5, 3.0),
'search_num_heads': (10, 9, 8),
'search_depth': (16, 15, 14),
})
else:
raise ValueError(f'Unsupported architecture with name: {name}')
model_factory = FixedFactory(cls, arch)
model = model_factory(**init_kwargs)
if pretrained:
weight_file = load_pretrained_weight(name, download=download, progress=progress)
pretrained_weights = torch.load(weight_file)
model.load_state_dict(pretrained_weights)
return model
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
x = x.permute(0, 2, 3, 1).view(B, self.patches_num, -1)
x = self.cls_token(x)
x = self.pos_embed(x)
x = self.blocks(x)
x = self.norm(x)
if self.global_pool:
x = torch.mean(x[:, 1:], dim=1)
else:
x = x[:, 0]
x = self.head(x)
return x
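# Usage sketch (assumption: the named checkpoints are resolvable via load_pretrained_weight if
# pretrained=True; here only the fixed architecture is built):
#
#     model = AutoformerSpace.load_searched_model('autoformer-tiny', pretrained=False)
#     logits = model(torch.randn(1, 3, 224, 224))   # -> shape (1, 1000)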
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from functools import partial
from typing import Tuple, Optional, Callable, Union, List, Type, cast
import torch
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper
from nni.typehint import Literal
from .proxylessnas import ConvBNReLU, InvertedResidual, DepthwiseSeparableConv, make_divisible, reset_parameters
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
class SqueezeExcite(nn.Module):
"""Squeeze-and-excite layer.
We can't use the op from ``torchvision.ops`` because it's not (yet) properly wrapped,
and ValueChoice cannot be processed.
Reference:
- https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/efficientnet_blocks.py#L26
- https://github.com/d-li14/mobilenetv3.pytorch/blob/3e6938cedcbbc5ee5bc50780ea18e644702d85fc/mobilenetv3.py#L53
"""
def __init__(self,
channels: int,
reduction_ratio: float = 0.25,
gate_layer: Optional[Callable[..., nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None):
super().__init__()
rd_channels = make_divisible(channels * reduction_ratio, 8)
gate_layer = gate_layer or nn.Hardsigmoid
activation_layer = activation_layer or nn.ReLU
self.conv_reduce = nn.Conv2d(channels, rd_channels, 1, bias=True)
self.act1 = activation_layer(inplace=True)
self.conv_expand = nn.Conv2d(rd_channels, channels, 1, bias=True)
self.gate = gate_layer()
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
x_se = self.conv_reduce(x_se)
x_se = self.act1(x_se)
x_se = self.conv_expand(x_se)
return x * self.gate(x_se)
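# Usage sketch: squeeze-and-excite preserves both the spatial size and the channel count.
#
#     se = SqueezeExcite(64)
#     y = se(torch.randn(2, 64, 32, 32))   # -> shape (2, 64, 32, 32)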
def _se_or_skip(hidden_ch: int, input_ch: int, optional: bool, se_from_exp: bool, label: str) -> nn.Module:
ch = hidden_ch if se_from_exp else input_ch
if optional:
return nn.LayerChoice({
'identity': nn.Identity(),
'se': SqueezeExcite(ch)
}, label=label)
else:
return SqueezeExcite(ch)
def _act_fn(act_alias: Literal['hswish', 'swish', 'relu']) -> Type[nn.Module]:
if act_alias == 'hswish':
return nn.Hardswish
elif act_alias == 'swish':
return nn.SiLU
elif act_alias == 'relu':
return nn.ReLU
else:
raise ValueError(f'Unsupported act alias: {act_alias}')
@model_wrapper
class MobileNetV3Space(nn.Module):
"""
MobileNetV3Space implements the largest search space in `TuNAS <https://arxiv.org/abs/2008.06120>`__.
The search dimensions include widths, expand ratios, kernel sizes, SE ratio.
Some of them can be turned off via arguments to narrow down the search space.
Different from ProxylessNAS search space, this space is implemented with :class:`nn.ValueChoice`.
We use the following snippet as a reference.
https://github.com/google-research/google-research/blob/20736344591f774f4b1570af64624ed1e18d2867/tunas/mobile_search_space_v3.py#L728
We have ``num_blocks``, which equals the length of ``self.blocks`` (the main body of the network).
For simplicity, the following parameter specification assumes ``num_blocks`` equals 8 (body + head).
If a shallower body is intended, arrays including ``base_widths``, ``squeeze_excite``, ``depth_range``,
``stride``, ``activation`` should also be shortened accordingly.
Parameters
----------
num_labels
Dimensions for classification head.
base_widths
Widths of each stage, from stem, to body, to head.
Length should be 9, i.e., ``num_blocks + 1`` (because there is a stem width in front).
width_multipliers
A range of width multipliers to choose from. The choice is independent for each stage.
Or it can be a fixed float. This will be applied to ``base_widths``,
and the resulting widths are made divisible by 8.
expand_ratios
A list of expand ratios to choose from. Independent for every **block**.
squeeze_excite
Indicates whether each stage can have an optional SE layer.
Expects an array of length 6 for stages 0 to 5. Each element can be one of ``force``, ``optional``, ``none``.
depth_range
A range (e.g., ``(1, 4)``),
or a list of ranges (e.g., ``[(1, 3), (1, 4), (1, 4), (1, 3), (0, 2)]``).
If a list, the length should be 5. The depths are specified for stages 1 to 5.
stride
Stride for all stages (including stem and head). Length should be same as ``base_widths``.
activation
Activation (class) for all stages. Length is same as ``base_widths``.
se_from_exp
Calculate SE channel reduction from expanded (mid) channels.
dropout_rate
Dropout rate at classification head.
bn_eps
Epsilon of batch normalization.
bn_momentum
Momentum of batch normalization.
"""
widths: List[Union[nn.ChoiceOf[int], int]]
depth_range: List[Tuple[int, int]]
def __init__(
self, num_labels: int = 1000,
base_widths: Tuple[int, ...] = (16, 16, 16, 32, 64, 128, 256, 512, 1024),
width_multipliers: Union[Tuple[float, ...], float] = (0.5, 0.625, 0.75, 1.0, 1.25, 1.5, 2.0),
expand_ratios: Tuple[float, ...] = (1., 2., 3., 4., 5., 6.),
squeeze_excite: Tuple[Literal['force', 'optional', 'none'], ...] = (
'none', 'none', 'optional', 'none', 'optional', 'optional'
),
depth_range: Union[List[Tuple[int, int]], Tuple[int, int]] = (1, 4),
stride: Tuple[int, ...] = (2, 1, 2, 2, 2, 1, 2, 1, 1),
activation: Tuple[Literal['hswish', 'swish', 'relu'], ...] = (
'hswish', 'relu', 'relu', 'relu', 'hswish', 'hswish', 'hswish', 'hswish', 'hswish'
),
se_from_exp: bool = True,
dropout_rate: float = 0.2,
bn_eps: float = 1e-3,
bn_momentum: float = 0.1
):
super().__init__()
self.num_blocks = len(base_widths) - 1 # without stem, equal to len(self.blocks)
assert self.num_blocks >= 4
assert len(base_widths) == len(stride) == len(activation) == self.num_blocks + 1
# The final two blocks can't have SE
assert len(squeeze_excite) == self.num_blocks - 2 and all(se in ['force', 'optional', 'none'] for se in squeeze_excite)
# The first block and the final two blocks can't have a variable depth
if isinstance(depth_range[0], int):
depth_range = cast(Tuple[int, int], depth_range)
assert len(depth_range) == 2 and depth_range[1] >= depth_range[0] >= 1
self.depth_range = [depth_range] * (self.num_blocks - 3)
else:
assert len(depth_range) == self.num_blocks - 3
self.depth_range = cast(List[Tuple[int, int]], depth_range)
for d in self.depth_range:
d = cast(Tuple[int, int], d)
# pylint: disable=unsubscriptable-object
assert len(d) == 2 and d[1] >= d[0] >= 1, f'{d} does not satisfy depth constraints'
self.widths = []
for i, base_width in enumerate(base_widths):
if isinstance(width_multipliers, float):
self.widths.append(make_divisible(base_width * width_multipliers, 8))
else:
self.widths.append(
# According to tunas, stem and stage 0 share one width multiplier
# https://github.com/google-research/google-research/blob/20736344/tunas/mobile_search_space_v3.py#L791
make_divisible(
nn.ValueChoice(list(width_multipliers), label=f's{max(i - 1, 0)}_width_mult') * base_width, 8
)
)
self.expand_ratios = expand_ratios
self.se_from_exp = se_from_exp
# NOTE: The built-in hardswish produces slightly different output from 3rd-party implementation
# But I guess it doesn't really matter.
# https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/layers/activations.py#L79
self.stem = ConvBNReLU(
3, self.widths[0],
nn.ValueChoice([3, 5], label=f'stem_ks'),
stride=stride[0], activation_layer=_act_fn(activation[0])
)
blocks: List[nn.Module] = [
# Stage 0
# FIXME: this should be an optional layer.
# https://github.com/google-research/google-research/blob/20736344/tunas/mobile_search_space_v3.py#L791
DepthwiseSeparableConv(
self.widths[0], self.widths[1],
nn.ValueChoice([3, 5, 7], label=f's0_i0_ks'),
stride=stride[1],
squeeze_excite=cast(Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module], partial(
_se_or_skip, optional=squeeze_excite[0] == 'optional', se_from_exp=self.se_from_exp, label=f's0_i0_se'
)) if squeeze_excite[0] != 'none' else None,
activation_layer=_act_fn(activation[1])
),
]
blocks += [
# Stage 1-5 (by default)
self._make_stage(i, self.widths[i], self.widths[i + 1], squeeze_excite[i], stride[i + 1], _act_fn(activation[i + 1]))
for i in range(1, self.num_blocks - 2)
]
# Head
blocks += [
ConvBNReLU(
self.widths[self.num_blocks - 2],
self.widths[self.num_blocks - 1],
kernel_size=1,
stride=stride[self.num_blocks - 1],
activation_layer=_act_fn(activation[self.num_blocks - 1])
),
nn.AdaptiveAvgPool2d(1),
# In some implementations, this is a linear layer instead.
# Should be equivalent.
ConvBNReLU(
self.widths[self.num_blocks - 1],
self.widths[self.num_blocks],
kernel_size=1,
stride=stride[self.num_blocks],
norm_layer=nn.Identity,
activation_layer=_act_fn(activation[self.num_blocks])
)
]
self.blocks = nn.Sequential(*blocks)
self.classifier = nn.Sequential(
nn.Dropout(dropout_rate),
nn.Linear(cast(int, self.widths[self.num_blocks]), num_labels),
)
reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)
def forward(self, x):
x = self.stem(x)
x = self.blocks(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
def _make_stage(self, stage_idx, inp, oup, se, stride, act):
def layer_builder(idx):
exp = nn.ValueChoice(list(self.expand_ratios), label=f's{stage_idx}_i{idx}_exp')
ks = nn.ValueChoice([3, 5, 7], label=f's{stage_idx}_i{idx}_ks')
# if SE is true, assign a layer choice to SE
se_or_skip = cast(Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module], partial(
_se_or_skip, optional=se == 'optional', se_from_exp=self.se_from_exp, label=f's{stage_idx}_i{idx}_se'
)) if se != 'none' else None
return InvertedResidual(
inp if idx == 0 else oup,
oup, exp, ks,
stride=stride if idx == 0 else 1, # only the first layer in each stage can have stride > 1
squeeze_excite=se_or_skip,
activation_layer=act,
)
# mutable depth
min_depth, max_depth = self.depth_range[stage_idx - 1]
if stride != 1:
min_depth = max(min_depth, 1)
return nn.Repeat(layer_builder, depth=(min_depth, max_depth), label=f's{stage_idx}_depth')
@classmethod
def fixed_arch(cls, arch: dict) -> FixedFactory:
return FixedFactory(cls, arch)
@classmethod
def load_searched_model(
cls, name: str,
pretrained: bool = False, download: bool = False, progress: bool = True
) -> nn.Module:
init_kwargs = {} # all default
if name == 'mobilenetv3-large-100':
# NOTE: Use bicubic interpolation to evaluate this.
# With the default interpolation, it yields top-1 75.722.
arch = {
'stem_ks': 3,
's0_i0_ks': 3,
's1_depth': 2,
's1_i0_exp': 4,
's1_i0_ks': 3,
's1_i1_exp': 3,
's1_i1_ks': 3,
's2_depth': 3,
's2_i0_exp': 3,
's2_i0_ks': 5,
's2_i1_exp': 3,
's2_i1_ks': 5,
's2_i2_exp': 3,
's2_i2_ks': 5,
's3_depth': 4,
's3_i0_exp': 6,
's3_i0_ks': 3,
's3_i1_exp': 2.5,
's3_i1_ks': 3,
's3_i2_exp': 2.3,
's3_i2_ks': 3,
's3_i3_exp': 2.3,
's3_i3_ks': 3,
's4_depth': 2,
's4_i0_exp': 6,
's4_i0_ks': 3,
's4_i1_exp': 6,
's4_i1_ks': 3,
's5_depth': 3,
's5_i0_exp': 6,
's5_i0_ks': 5,
's5_i1_exp': 6,
's5_i1_ks': 5,
's5_i2_exp': 6,
's5_i2_ks': 5,
}
init_kwargs.update(
base_widths=[16, 16, 24, 40, 80, 112, 160, 960, 1280],
expand_ratios=[1.0, 2.0, 2.3, 2.5, 3.0, 4.0, 6.0],
bn_eps=1e-5,
bn_momentum=0.1,
width_multipliers=1.0,
squeeze_excite=['none', 'none', 'force', 'none', 'force', 'force']
)
elif name.startswith('mobilenetv3-small-'):
# Evaluate with bicubic interpolation
multiplier = int(name.split('-')[-1]) / 100
widths = [16, 16, 24, 40, 48, 96, 576, 1024]
for i in range(7):
if i > 0 or multiplier >= 0.75:
# fix_stem = True when multiplier < 0.75
# https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/mobilenetv3.py#L421
widths[i] = make_divisible(widths[i] * multiplier, 8)
init_kwargs.update(
base_widths=widths,
width_multipliers=1.0,
expand_ratios=[3.0, 3.67, 4.0, 4.5, 6.0],
bn_eps=1e-05,
bn_momentum=0.1,
squeeze_excite=['force', 'none', 'force', 'force', 'force'],
activation=['hswish', 'relu', 'relu', 'hswish', 'hswish', 'hswish', 'hswish', 'hswish'],
stride=[2, 2, 2, 2, 1, 2, 1, 1],
depth_range=(1, 2),
)
arch = {
'stem_ks': 3,
's0_i0_ks': 3,
's1_depth': 2,
's1_i0_exp': 4.5,
's1_i0_ks': 3,
's1_i1_exp': 3.67,
's1_i1_ks': 3,
's2_depth': 3,
's2_i0_exp': 4.0,
's2_i0_ks': 5,
's2_i1_exp': 6.0,
's2_i1_ks': 5,
's2_i2_exp': 6.0,
's2_i2_ks': 5,
's3_depth': 2,
's3_i0_exp': 3.0,
's3_i0_ks': 5,
's3_i1_exp': 3.0,
's3_i1_ks': 5,
's4_depth': 3,
's4_i0_exp': 6.0,
's4_i0_ks': 5,
's4_i1_exp': 6.0,
's4_i1_ks': 5,
's4_i2_exp': 6.0,
's4_i2_ks': 5
}
elif name.startswith('cream'):
# https://github.com/microsoft/Cream/tree/main/Cream
# bilinear interpolation
level = name.split('-')[-1]
# region cream arch specification
if level == '014':
arch = {
'stem_ks': 3,
's0_depth': 1,
's0_i0_ks': 3,
's1_depth': 1,
's1_i0_exp': 4.0,
's1_i0_ks': 3,
's2_depth': 2,
's2_i0_exp': 6.0,
's2_i0_ks': 5,
's2_i1_exp': 6.0,
's2_i1_ks': 5,
's3_depth': 2,
's3_i0_exp': 6.0,
's3_i0_ks': 5,
's3_i1_exp': 6.0,
's3_i1_ks': 5,
's4_depth': 1,
's4_i0_exp': 6.0,
's4_i0_ks': 3,
's5_depth': 1,
's5_i0_exp': 6.0,
's5_i0_ks': 5
}
elif level == '043':
arch = {
'stem_ks': 3,
's0_depth': 1,
's0_i0_ks': 3,
's1_depth': 1,
's1_i0_exp': 4.0,
's1_i0_ks': 3,
's2_depth': 2,
's2_i0_exp': 6.0,
's2_i0_ks': 5,
's2_i1_exp': 6.0,
's2_i1_ks': 3,
's3_depth': 2,
's3_i0_exp': 6.0,
's3_i0_ks': 5,
's3_i1_exp': 6.0,
's3_i1_ks': 3,
's4_depth': 3,
's4_i0_exp': 6.0,
's4_i0_ks': 5,
's4_i1_exp': 6.0,
's4_i1_ks': 5,
's4_i2_exp': 6.0,
's4_i2_ks': 5,
's5_depth': 2,
's5_i0_exp': 6.0,
's5_i0_ks': 5,
's5_i1_exp': 6.0,
's5_i1_ks': 5
}
elif level == '114':
arch = {
'stem_ks': 3,
's0_depth': 1,
's0_i0_ks': 3,
's1_depth': 1,
's1_i0_exp': 4.0,
's1_i0_ks': 3,
's2_depth': 2,
's2_i0_exp': 6.0,
's2_i0_ks': 5,
's2_i1_exp': 6.0,
's2_i1_ks': 5,
's3_depth': 2,
's3_i0_exp': 6.0,
's3_i0_ks': 5,
's3_i1_exp': 6.0,
's3_i1_ks': 5,
's4_depth': 3,
's4_i0_exp': 6.0,
's4_i0_ks': 5,
's4_i1_exp': 6.0,
's4_i1_ks': 5,
's4_i2_exp': 6.0,
's4_i2_ks': 5,
's5_depth': 2,
's5_i0_exp': 6.0,
's5_i0_ks': 5,
's5_i1_exp': 6.0,
's5_i1_ks': 5
}
elif level == '287':
arch = {
'stem_ks': 3,
's0_depth': 1,
's0_i0_ks': 3,
's1_depth': 1,
's1_i0_exp': 4.0,
's1_i0_ks': 3,
's2_depth': 2,
's2_i0_exp': 6.0,
's2_i0_ks': 5,
's2_i1_exp': 6.0,
's2_i1_ks': 5,
's3_depth': 3,
's3_i0_exp': 6.0,
's3_i0_ks': 5,
's3_i1_exp': 6.0,
's3_i1_ks': 3,
's3_i2_exp': 6.0,
's3_i2_ks': 5,
's4_depth': 4,
's4_i0_exp': 6.0,
's4_i0_ks': 5,
's4_i1_exp': 6.0,
's4_i1_ks': 5,
's4_i2_exp': 6.0,
's4_i2_ks': 5,
's4_i3_exp': 6.0,
's4_i3_ks': 5,
's5_depth': 3,
's5_i0_exp': 6.0,
's5_i0_ks': 5,
's5_i1_exp': 6.0,
's5_i1_ks': 5,
's5_i2_exp': 6.0,
's5_i2_ks': 5
}
elif level == '481':
arch = {
'stem_ks': 3,
's0_depth': 1,
's0_i0_ks': 3,
's1_depth': 4,
's1_i0_exp': 6.0,
's1_i0_ks': 5,
's1_i1_exp': 4.0,
's1_i1_ks': 7,
's1_i2_exp': 6.0,
's1_i2_ks': 5,
's1_i3_exp': 6.0,
's1_i3_ks': 3,
's2_depth': 4,
's2_i0_exp': 6.0,
's2_i0_ks': 5,
's2_i1_exp': 4.0,
's2_i1_ks': 5,
's2_i2_exp': 6.0,
's2_i2_ks': 5,
's2_i3_exp': 4.0,
's2_i3_ks': 3,
's3_depth': 5,
's3_i0_exp': 6.0,
's3_i0_ks': 5,
's3_i1_exp': 6.0,
's3_i1_ks': 5,
's3_i2_exp': 6.0,
's3_i2_ks': 5,
's3_i3_exp': 6.0,
's3_i3_ks': 3,
's3_i4_exp': 6.0,
's3_i4_ks': 3,
's4_depth': 4,
's4_i0_exp': 6.0,
's4_i0_ks': 5,
's4_i1_exp': 6.0,
's4_i1_ks': 5,
's4_i2_exp': 6.0,
's4_i2_ks': 5,
's4_i3_exp': 6.0,
's4_i3_ks': 5,
's5_depth': 4,
's5_i0_exp': 6.0,
's5_i0_ks': 5,
's5_i1_exp': 6.0,
's5_i1_ks': 5,
's5_i2_exp': 6.0,
's5_i2_ks': 5,
's5_i3_exp': 6.0,
's5_i3_ks': 5
}
elif level == '604':
arch = {
'stem_ks': 3,
's0_depth': 1,
's0_i0_ks': 3,
's1_depth': 5,
's1_i0_exp': 6.0,
's1_i0_ks': 5,
's1_i1_exp': 6.0,
's1_i1_ks': 5,
's1_i2_exp': 4.0,
's1_i2_ks': 5,
's1_i3_exp': 6.0,
's1_i3_ks': 5,
's1_i4_exp': 6.0,
's1_i4_ks': 5,
's2_depth': 5,
's2_i0_exp': 6.0,
's2_i0_ks': 5,
's2_i1_exp': 4.0,
's2_i1_ks': 5,
's2_i2_exp': 6.0,
's2_i2_ks': 5,
's2_i3_exp': 4.0,
's2_i3_ks': 5,
's2_i4_exp': 6.0,
's2_i4_ks': 5,
's3_depth': 5,
's3_i0_exp': 6.0,
's3_i0_ks': 5,
's3_i1_exp': 4.0,
's3_i1_ks': 5,
's3_i2_exp': 6.0,
's3_i2_ks': 5,
's3_i3_exp': 4.0,
's3_i3_ks': 5,
's3_i4_exp': 6.0,
's3_i4_ks': 5,
's4_depth': 6,
's4_i0_exp': 6.0,
's4_i0_ks': 5,
's4_i1_exp': 6.0,
's4_i1_ks': 5,
's4_i2_exp': 4.0,
's4_i2_ks': 5,
's4_i3_exp': 4.0,
's4_i3_ks': 5,
's4_i4_exp': 6.0,
's4_i4_ks': 5,
's4_i5_exp': 6.0,
's4_i5_ks': 5,
's5_depth': 6,
's5_i0_exp': 6.0,
's5_i0_ks': 5,
's5_i1_exp': 6.0,
's5_i1_ks': 5,
's5_i2_exp': 4.0,
's5_i2_ks': 5,
's5_i3_exp': 6.0,
's5_i3_ks': 5,
's5_i4_exp': 6.0,
's5_i4_ks': 5,
's5_i5_exp': 6.0,
's5_i5_ks': 5
}
else:
raise ValueError(f'Unsupported cream model level: {level}')
# endregion
init_kwargs.update(
base_widths=[16, 16, 24, 40, 80, 96, 192, 320, 1280],
width_multipliers=1.0,
expand_ratios=[4.0, 6.0],
bn_eps=1e-5,
bn_momentum=0.1,
squeeze_excite=['force'] * 6,
activation=['swish'] * 9
)
else:
raise ValueError(f'Unsupported architecture with name: {name}')
model_factory = cls.fixed_arch(arch)
model = model_factory(**init_kwargs)
if pretrained:
weight_file = load_pretrained_weight(name, download=download, progress=progress)
pretrained_weights = torch.load(weight_file)
model.load_state_dict(pretrained_weights)
return model
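# Usage sketch (assumption: pretrained weights, if requested, are resolvable via
# load_pretrained_weight; here only the fixed architecture is built):
#
#     model = MobileNetV3Space.load_searched_model('mobilenetv3-large-100', pretrained=False)
#     logits = model(torch.randn(1, 3, 224, 224))   # -> shape (1, 1000)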