Unverified Commit 18962129 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Add license header and typehints for NAS (#4774)

parent 8c2f717d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools
from peewee import fn
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .model import NlpTrialStats, NlpIntermediateStats, NlpTrialConfig
from .query import query_nlp_trial_stats
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import os
import argparse
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools
from peewee import fn
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools
import hashlib
import json
......@@ -6,6 +9,7 @@ import os
import shutil
import tempfile
from pathlib import Path
from typing import Optional
import requests
import tqdm
......@@ -49,8 +53,9 @@ def load_or_download_file(local_path: str, download_url: str, download: bool = F
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
r = requests.get(download_url, stream=True)
total_length = int(r.headers.get('content-length'))
with tqdm.tqdm(total=total_length, disable=not progress,
total_length: Optional[str] = r.headers.get('content-length')
assert total_length is not None, f'Content length is not found in the response of {download_url}'
with tqdm.tqdm(total=int(total_length), disable=not progress,
unit='B', unit_scale=True, unit_divisor=1024) as pbar:
for chunk in r.iter_content(8192):
f.write(chunk)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base_mutator import BaseMutator
from .base_trainer import BaseTrainer
from .fixed import apply_fixed_architecture
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .nasbench201 import NASBench201Cell
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import OrderedDict
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .darts_cell import DartsCell
from .enas_cell import ENASMicroLayer
from .enas_cell import ENASMacroLayer
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .pytorch import model_to_pytorch_script
......@@ -3,8 +3,9 @@
import logging
import re
from typing import Dict, List, Tuple, Any
from typing import Dict, List, Tuple, Any, cast
from nni.retiarii.operation import PyTorchOperation
from nni.retiarii.operation_def.torch_op_def import ToDevice
from nni.retiarii.utils import STATE_DICT_PY_MAPPING
from nni.common.device import Device, GPUDevice
......@@ -34,7 +35,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
if all(edge.tail_slot is None for edge in edges):
return edges
if all(isinstance(edge.tail_slot, int) for edge in edges):
edges = sorted(edges, key=(lambda edge: edge.tail_slot))
edges = sorted(edges, key=(lambda edge: cast(int, edge.tail_slot)))
if [edge.tail_slot for edge in edges] == list(range(len(edges))):
return edges
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
......@@ -98,7 +99,7 @@ def _format_variable_name(name: str, graph_name: str) -> str:
name = name.replace('/', '__')
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
name = re.sub('\W|^(?=\d)','_', name)
name = re.sub(r'\W|^(?=\d)','_', name)
if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
# name can't start with double underscore
......@@ -130,7 +131,7 @@ def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
return cuda_remapped_id
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> Tuple[set, str]:
nodes = graph.topo_sort()
# handle module node and function node differently
......@@ -144,11 +145,12 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
for node in nodes:
if node.operation:
if placement and isinstance(node.operation, ToDevice):
cuda_remapped_id = cast(dict, cuda_remapped_id)
node.operation.override_device_repr("cuda:%d" % cuda_remapped_id[node.operation.device])
if node.operation.type == 'shared':
continue
pkg_name = node.operation.get_import_pkg()
pkg_name = cast(PyTorchOperation, node.operation).get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
......@@ -157,6 +159,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
if isinstance(placement[node], GPUDevice):
assert cuda_remapped_id is not None
device_repr = "cuda:%d" % cuda_remapped_id[placement[node]]
else:
device_repr = placement[node].device_repr()
......
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
# pylint: skip-file
# type: ignore
"""
FIXME
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .graph_gen import convert_to_graph
......@@ -411,6 +411,7 @@ class GraphConverter:
edge.graph = ir_graph
if edge.head == method_ir_graph.input_node:
# this is a member method, 'self' is the first argument, thus +1
assert edge.head_slot is not None
_input = node.inputsAt(edge.head_slot + 1)
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
edge.head = src_node
......@@ -745,6 +746,7 @@ class GraphConverterWithShape(GraphConverter):
if isinstance(submodule, LayerChoice):
full_name = get_full_name_by_scope_name(ir_model, name.split('.'), module_name)
lc_node = ir_model.get_node_by_name(full_name)
assert lc_node is not None, f'Cannot find a node with name {full_name}'
for cand_name in submodule.names:
cand = submodule[cand_name]
......@@ -761,12 +763,14 @@ class GraphConverterWithShape(GraphConverter):
return
graph_node = ir_model.get_node_by_name(graph.name)
assert graph_node is not None, f'Cannot find a node with name {graph.name}'
if not _without_shape_info(graph_node):
return
if is_layerchoice_node(graph_node):
cand_name = graph_node.operation.parameters['candidates'][0]
cand_node = ir_model.get_node_by_name(cand_name)
assert cand_node is not None, f'Cannot find a node with name {cand_name}'
if _without_shape_info(cand_node):
propagate_shape_for_graph(ir_model.graphs[cand_name])
graph_node.operation.attributes['input_shape'] = cand_node.operation.attributes['input_shape']
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional
from typing_extensions import TypeGuard
from ..operation import Cell
from ..graph import Model, Graph, Node, Edge
......@@ -77,7 +81,7 @@ def _extract_info_from_trace_node(trace_node):
return shape_parameters, None
def is_layerchoice_node(ir_node: Node):
def is_layerchoice_node(ir_node: Optional[Node]) -> TypeGuard[Node]:
if ir_node is not None and isinstance(ir_node.operation, Cell) and ir_node.operation.parameters.get('mutation') == 'layerchoice':
return True
else:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# we will support tensorflow in future release
framework = 'pytorch'
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .lightning import *
......@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import warnings
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Type
import torch.nn as nn
......@@ -18,7 +18,7 @@ from .trainer import Trainer
@nni.trace
class _MultiModelSupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, torchmetrics.Metric],
n_models: int = 0,
learning_rate: float = 0.001,
weight_decay: float = 0.,
......@@ -171,7 +171,7 @@ class Classification(Lightning):
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
......@@ -184,7 +184,7 @@ class Classification(Lightning):
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
class _RegressionModule(_MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment