"test/batched_gemm/batched_gemm_int8.cpp" did not exist on "abf4bdb9a9946c578d4801a79650e79938fb0e41"
Unverified commit 18962129, authored by Yuge Zhang, committed by GitHub

Add license header and typehints for NAS (#4774)

parent 8c2f717d
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 import functools
 from peewee import fn
...
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 from .model import NlpTrialStats, NlpIntermediateStats, NlpTrialConfig
 from .query import query_nlp_trial_stats
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 import json
 import os
 import argparse
...
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 import os
 from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
...
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 import functools
 from peewee import fn
...
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 import functools
 import hashlib
 import json
@@ -6,6 +9,7 @@ import os
 import shutil
 import tempfile
 from pathlib import Path
+from typing import Optional
 import requests
 import tqdm
@@ -49,8 +53,9 @@ def load_or_download_file(local_path: str, download_url: str, download: bool = F
 f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
 r = requests.get(download_url, stream=True)
-total_length = int(r.headers.get('content-length'))
-with tqdm.tqdm(total=total_length, disable=not progress,
+total_length: Optional[str] = r.headers.get('content-length')
+assert total_length is not None, f'Content length is not found in the response of {download_url}'
+with tqdm.tqdm(total=int(total_length), disable=not progress,
 unit='B', unit_scale=True, unit_divisor=1024) as pbar:
 for chunk in r.iter_content(8192):
 f.write(chunk)
...
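Context for the hunk above: `requests` exposes header values as strings, so `r.headers.get('content-length')` is `Optional[str]`; the added assert fails loudly when the server omits the header, before the value is converted to int. A standalone sketch of the same pattern (the function name, URL, and path below are placeholders, not part of the commit):

from typing import Optional
import requests
import tqdm

def download_with_progress(url: str, path: str) -> None:
    r = requests.get(url, stream=True)
    # headers.get returns Optional[str]; fail explicitly if the header is missing.
    total_length: Optional[str] = r.headers.get('content-length')
    assert total_length is not None, f'Content length is not found in the response of {url}'
    with open(path, 'wb') as f, tqdm.tqdm(total=int(total_length), unit='B',
                                          unit_scale=True, unit_divisor=1024) as pbar:
        for chunk in r.iter_content(8192):
            f.write(chunk)
            pbar.update(len(chunk))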
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 from .base_mutator import BaseMutator
 from .base_trainer import BaseTrainer
 from .fixed import apply_fixed_architecture
...
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 from .nasbench201 import NASBench201Cell
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 from collections import OrderedDict
 import torch.nn as nn
 from nni.nas.pytorch.mutables import LayerChoice
...
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 import torch
 import torch.nn as nn
...
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 from .darts_cell import DartsCell
 from .enas_cell import ENASMicroLayer
 from .enas_cell import ENASMacroLayer
...
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 from .pytorch import model_to_pytorch_script
@@ -3,8 +3,9 @@
 import logging
 import re
-from typing import Dict, List, Tuple, Any
+from typing import Dict, List, Tuple, Any, cast
+from nni.retiarii.operation import PyTorchOperation
 from nni.retiarii.operation_def.torch_op_def import ToDevice
 from nni.retiarii.utils import STATE_DICT_PY_MAPPING
 from nni.common.device import Device, GPUDevice
@@ -34,7 +35,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
 if all(edge.tail_slot is None for edge in edges):
 return edges
 if all(isinstance(edge.tail_slot, int) for edge in edges):
-edges = sorted(edges, key=(lambda edge: edge.tail_slot))
+edges = sorted(edges, key=(lambda edge: cast(int, edge.tail_slot)))
 if [edge.tail_slot for edge in edges] == list(range(len(edges))):
 return edges
 raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
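Background on the `cast` calls introduced above: `typing.cast` is a runtime no-op that only tells the type checker to treat a value as the given type, which is why it can silence an Optional complaint without changing behavior. A minimal sketch (not the NNI code itself):

from typing import List, Optional, cast

def sort_slots(slots: List[Optional[int]]) -> List[Optional[int]]:
    # The checker cannot prove the slots are non-None here; cast() asserts
    # they are ints for type-checking purposes, with no runtime check at all.
    return sorted(slots, key=lambda s: cast(int, s))

print(sort_slots([2, 0, 1]))  # [0, 1, 2]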
@@ -98,7 +99,7 @@ def _format_variable_name(name: str, graph_name: str) -> str:
 name = name.replace('/', '__')
 # https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
-name = re.sub('\W|^(?=\d)','_', name)
+name = re.sub(r'\W|^(?=\d)','_', name)
 if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
 # name can't start with double underscore
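The raw-string prefix added above avoids Python's DeprecationWarning for invalid escape sequences such as `\W` in plain string literals; the pattern itself is unchanged. A quick illustration of what it does (the input names below are hypothetical, not from the NNI codebase):

import re

def to_identifier(name: str) -> str:
    # Replace any non-word character, or the empty position before a leading digit,
    # with an underscore, producing a valid Python identifier.
    return re.sub(r'\W|^(?=\d)', '_', name)

print(to_identifier('layer/0.conv-1'))  # layer_0_conv_1
print(to_identifier('3x3_conv'))        # _3x3_conv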
@@ -130,7 +131,7 @@ def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
 return cuda_remapped_id
-def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
+def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> Tuple[set, str]:
 nodes = graph.topo_sort()
 # handle module node and function node differently
@@ -144,11 +145,12 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
 for node in nodes:
 if node.operation:
 if placement and isinstance(node.operation, ToDevice):
+cuda_remapped_id = cast(dict, cuda_remapped_id)
 node.operation.override_device_repr("cuda:%d" % cuda_remapped_id[node.operation.device])
 if node.operation.type == 'shared':
 continue
-pkg_name = node.operation.get_import_pkg()
+pkg_name = cast(PyTorchOperation, node.operation).get_import_pkg()
 if pkg_name is not None:
 import_pkgs.add(pkg_name)
@@ -157,6 +159,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
 if node_code is not None:
 if placement and node in placement and len(node_code) > 0:
 if isinstance(placement[node], GPUDevice):
+assert cuda_remapped_id is not None
 device_repr = "cuda:%d" % cuda_remapped_id[placement[node]]
 else:
 device_repr = placement[node].device_repr()
...
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.
 # pylint: skip-file
+# type: ignore
 """
 FIXME
...
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 from .graph_gen import convert_to_graph
@@ -411,6 +411,7 @@ class GraphConverter:
 edge.graph = ir_graph
 if edge.head == method_ir_graph.input_node:
 # this is a member method, 'self' is the first argument, thus +1
+assert edge.head_slot is not None
 _input = node.inputsAt(edge.head_slot + 1)
 src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
 edge.head = src_node
@@ -745,6 +746,7 @@ class GraphConverterWithShape(GraphConverter):
 if isinstance(submodule, LayerChoice):
 full_name = get_full_name_by_scope_name(ir_model, name.split('.'), module_name)
 lc_node = ir_model.get_node_by_name(full_name)
+assert lc_node is not None, f'Cannot find a node with name {full_name}'
 for cand_name in submodule.names:
 cand = submodule[cand_name]
@@ -761,12 +763,14 @@ class GraphConverterWithShape(GraphConverter):
 return
 graph_node = ir_model.get_node_by_name(graph.name)
+assert graph_node is not None, f'Cannot find a node with name {graph.name}'
 if not _without_shape_info(graph_node):
 return
 if is_layerchoice_node(graph_node):
 cand_name = graph_node.operation.parameters['candidates'][0]
 cand_node = ir_model.get_node_by_name(cand_name)
+assert cand_node is not None, f'Cannot find a node with name {cand_name}'
 if _without_shape_info(cand_node):
 propagate_shape_for_graph(ir_model.graphs[cand_name])
 graph_node.operation.attributes['input_shape'] = cand_node.operation.attributes['input_shape']
...
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+from typing import Optional
+from typing_extensions import TypeGuard
 from ..operation import Cell
 from ..graph import Model, Graph, Node, Edge
@@ -77,7 +81,7 @@ def _extract_info_from_trace_node(trace_node):
 return shape_parameters, None
-def is_layerchoice_node(ir_node: Node):
+def is_layerchoice_node(ir_node: Optional[Node]) -> TypeGuard[Node]:
 if ir_node is not None and isinstance(ir_node.operation, Cell) and ir_node.operation.parameters.get('mutation') == 'layerchoice':
 return True
 else:
...
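For background on the `TypeGuard[Node]` return annotation above: a function annotated as returning `TypeGuard[T]` tells the type checker that when it returns True, the checked argument can be narrowed to `T`. A self-contained sketch of the idea (the Node class here is a stand-in, not NNI's):

from typing import Optional
from typing_extensions import TypeGuard

class Node:
    name = 'n'

def is_real_node(node: Optional[Node]) -> TypeGuard[Node]:
    # True only when node is an actual Node, so callers can drop the Optional.
    return node is not None

def describe(node: Optional[Node]) -> str:
    if is_real_node(node):
        return node.name  # narrowed to Node here, no Optional complaint
    return '<missing>'

print(describe(Node()), describe(None))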
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 # we will support tensorflow in future release
 framework = 'pytorch'
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 from .lightning import *
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.
 import warnings
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Type
 import torch.nn as nn
@@ -18,7 +18,7 @@ from .trainer import Trainer
 @nni.trace
 class _MultiModelSupervisedLearningModule(LightningModule):
-def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
+def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, torchmetrics.Metric],
 n_models: int = 0,
 learning_rate: float = 0.001,
 weight_decay: float = 0.,
@@ -171,7 +171,7 @@ class Classification(Lightning):
 `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
 """
-def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
+def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
 learning_rate: float = 0.001,
 weight_decay: float = 0.,
 optimizer: optim.Optimizer = optim.Adam,
@@ -184,7 +184,7 @@ class Classification(Lightning):
 train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
 class _RegressionModule(_MultiModelSupervisedLearningModule):
-def __init__(self, criterion: nn.Module = nn.MSELoss,
+def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
 learning_rate: float = 0.001,
 weight_decay: float = 0.,
 optimizer: optim.Optimizer = optim.Adam):
...
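The `criterion: Type[nn.Module]` changes above, together with the class defaults like `nn.CrossEntropyLoss`, indicate that the loss is passed as a class and constructed inside the module rather than passed as an instance. A minimal sketch of that pattern, independent of NNI's actual Lightning modules:

from typing import Type
import torch
import torch.nn as nn

class SupervisedHead(nn.Module):
    # criterion is a class, not an instance; it is instantiated in __init__.
    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss):
        super().__init__()
        self.criterion = criterion()

    def loss(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return self.criterion(logits, target)

head = SupervisedHead(nn.CrossEntropyLoss)
print(head.loss(torch.randn(4, 3), torch.tensor([0, 1, 2, 0])))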