Unverified Commit 7a1f05ae authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Merge pull request #3444 from microsoft/v2.1

V2.1 merge back to master
parents 539a7cd7 a0ae02e6
...@@ -29,7 +29,10 @@ def canonical_path(path: Optional[PathLike]) -> Optional[str]: ...@@ -29,7 +29,10 @@ def canonical_path(path: Optional[PathLike]) -> Optional[str]:
def count(*values) -> int: def count(*values) -> int:
return sum(value is not None and value is not False for value in values) return sum(value is not None and value is not False for value in values)
def training_service_config_factory(platform: Union[str, List[str]] = None, config: Union[List, Dict] = None): # -> TrainingServiceConfig def training_service_config_factory(
platform: Union[str, List[str]] = None,
config: Union[List, Dict] = None,
base_path: Optional[Path] = None): # -> TrainingServiceConfig
from .common import TrainingServiceConfig from .common import TrainingServiceConfig
ts_configs = [] ts_configs = []
if platform is not None: if platform is not None:
...@@ -47,7 +50,7 @@ def training_service_config_factory(platform: Union[str, List[str]] = None, conf ...@@ -47,7 +50,7 @@ def training_service_config_factory(platform: Union[str, List[str]] = None, conf
for conf in configs: for conf in configs:
if conf['platform'] not in supported_platforms: if conf['platform'] not in supported_platforms:
raise RuntimeError(f'Unrecognized platform {conf["platform"]}') raise RuntimeError(f'Unrecognized platform {conf["platform"]}')
ts_configs.append(supported_platforms[conf['platform']](**conf)) ts_configs.append(supported_platforms[conf['platform']](_base_path=base_path, **conf))
return ts_configs if len(ts_configs) > 1 else ts_configs[0] return ts_configs if len(ts_configs) > 1 else ts_configs[0]
def load_config(Type, value): def load_config(Type, value):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .operation import Operation from .operation import Operation
from .graph import * from .graph import *
from .execution import * from .execution import *
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging import logging
from typing import List, Tuple, Any from typing import List, Tuple, Any
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file # pylint: skip-file
""" """
......
# PyTorch Graph Converter
## Namespace for PyTorch Graph
We should have a concrete rule for specifying nodes in graph with namespace.
Each node has a name, either specified or generated. The nodes in the same hierarchy cannot have the same name.
* The name of module node natively follows this rule, because we use variable name for instantiated modules like what PyTorch graph does.
* For the nodes created in `forward` function, we use a global sequence number.
### Namespace for mutated (new) nodes
TBD
## Graph Simplification
TBD
## Node Types
We define concrete type string for each node type.
## Module's Input Arguments
We use wrapper to obtain the input arguments of modules. Users need to use our wrapped "nn" and wrapped "Module".
## Control Flow
### for loop
Currently, we only support `ModuleList` (`ModuleDict`) based for loop, which is automatically unfolded by TorchScript. That is to say, we do not support loop in TorchScript for now.
### if/else
For now, we only deal with the case that the condition is constant or attribute. In this case, only one branch is kept during generating the graph.
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import re import re
import torch import torch
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from enum import Enum from enum import Enum
MODULE_EXCEPT_LIST = ['Sequential'] MODULE_EXCEPT_LIST = ['Sequential']
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
def build_full_name(prefix, name, seq=None): def build_full_name(prefix, name, seq=None):
if isinstance(name, list): if isinstance(name, list):
name = '__'.join(name) name = '__'.join(name)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import graphviz import graphviz
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ..graph import Evaluator from ..graph import Evaluator
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# This file is deprecated. # This file is deprecated.
import abc import abc
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import warnings import warnings
from typing import Dict, Union, Optional, List from typing import Dict, Union, Optional, List
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import time import time
from ..graph import Model, ModelStatus from ..graph import Model, ModelStatus
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging import logging
import os import os
import random import random
...@@ -63,13 +66,14 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -63,13 +66,14 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def _send_trial_callback(self, paramater: dict) -> None: def _send_trial_callback(self, paramater: dict) -> None:
if self.resources <= 0: if self.resources <= 0:
_logger.warning('There is no available resource, but trial is submitted.') # FIXME: should be a warning message here
_logger.debug('There is no available resource, but trial is submitted.')
self.resources -= 1 self.resources -= 1
_logger.info('Resource used. Remaining: %d', self.resources) _logger.debug('Resource used. Remaining: %d', self.resources)
def _request_trial_jobs_callback(self, num_trials: int) -> None: def _request_trial_jobs_callback(self, num_trials: int) -> None:
self.resources += num_trials self.resources += num_trials
_logger.info('New resource available. Remaining: %d', self.resources) _logger.debug('New resource available. Remaining: %d', self.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None: def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id] model = self._running_models[trial_id]
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging import logging
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC, abstractmethod, abstractclassmethod from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, NewType, List, Union from typing import Any, NewType, List, Union
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ..graph import Model, ModelStatus from ..graph import Model, ModelStatus
from .interface import MetricData, AbstractGraphListener from .interface import MetricData, AbstractGraphListener
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC from abc import ABC
from .logical_plan import LogicalPlan from .logical_plan import LogicalPlan
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy import copy
from typing import Dict, Tuple, List, Any from typing import Dict, Tuple, List, Any
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
from nni.retiarii.utils import uid from nni.retiarii.utils import uid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment