Unverified Commit 37b73e4e authored by J-shang's avatar J-shang Committed by GitHub
Browse files

fix pipeline (#4721)

parent 303862a3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .one_shot_pruner import * from .one_shot_pruner import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import tensorflow as tf import tensorflow as tf
from nni.compression.tensorflow import Pruner from nni.compression.tensorflow import Pruner
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .compressor import Compressor, LayerInfo from .compressor import Compressor, LayerInfo
from .pruner import Pruner, PrunerModuleWrapper from .pruner import Pruner, PrunerModuleWrapper
from .scheduler import BasePruningScheduler, Task, TaskResult from .scheduler import BasePruningScheduler, Task, TaskResult
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .basic_pruner import * from .basic_pruner import *
from .basic_scheduler import PruningScheduler from .basic_scheduler import PruningScheduler
from .iterative_pruner import * from .iterative_pruner import *
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base import ( from .base import (
HookCollectorInfo, HookCollectorInfo,
DataCollector, DataCollector,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .agent import DDPG from .agent import DDPG
from .amc_env import AMCEnv from .amc_env import AMCEnv
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .config_validation import CompressorSchema from .config_validation import CompressorSchema
from .pruning import ( from .pruning import (
config_list_canonical, config_list_canonical,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .gbdt_selector import GBDTSelector from .gbdt_selector import GBDTSelector
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .gradient_selector import FeatureGradientSelector from .gradient_selector import FeatureGradientSelector
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from enum import Enum, EnumMeta from enum import Enum, EnumMeta
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from torch.quantization import default_weight_observer, default_histogram_observer from torch.quantization import default_weight_observer, default_histogram_observer
from torch.quantization import RecordingObserver as _RecordingObserver from torch.quantization import RecordingObserver as _RecordingObserver
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Optional from typing import Any, Optional
from .literal import QuantDtype, QuantType, QuantScheme from .literal import QuantDtype, QuantType, QuantScheme
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch import torch
from nni.common.version import TORCH_VERSION from nni.common.version import TORCH_VERSION
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .integrated_tensorrt import CalibrateType, ModelSpeedupTensorRT from .integrated_tensorrt import CalibrateType, ModelSpeedupTensorRT
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .compressor import ModelSpeedup from .compressor import ModelSpeedup
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .utils import * from .utils import *
from .shape_dependency import * from .shape_dependency import *
from .shape_dependency import ReshapeDependency from .shape_dependency import ReshapeDependency
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
def get_total_num_weights(model, op_types=['default']): def get_total_num_weights(model, op_types=['default']):
''' '''
calculate the total number of weights calculate the total number of weights
......
...@@ -120,7 +120,11 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix): ...@@ -120,7 +120,11 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
link = Path(config.experiment_working_directory, '_latest') link = Path(config.experiment_working_directory, '_latest')
try: try:
link.unlink(missing_ok=True) if sys.version_info >= (3, 8):
link.unlink(missing_ok=True)
else:
if link.exists():
link.unlink()
link.symlink_to(exp_id, target_is_directory=True) link.symlink_to(exp_id, target_is_directory=True)
except Exception: except Exception:
if sys.platform != 'win32': if sys.platform != 'win32':
......
...@@ -223,7 +223,16 @@ def copy_nni_node(version): ...@@ -223,7 +223,16 @@ def copy_nni_node(version):
""" """
_print('Copying files') _print('Copying files')
shutil.copytree('ts/nni_manager/dist', 'nni_node', dirs_exist_ok=True) if sys.version_info >= (3, 8):
shutil.copytree('ts/nni_manager/dist', 'nni_node', dirs_exist_ok=True)
else:
for item in os.listdir('ts/nni_manager/dist'):
subsrc = os.path.join('ts/nni_manager/dist', item)
subdst = os.path.join('nni_node', item)
if os.path.isdir(subsrc):
shutil.copytree(subsrc, subdst)
else:
shutil.copy2(subsrc, subdst)
shutil.copyfile('ts/nni_manager/yarn.lock', 'nni_node/yarn.lock') shutil.copyfile('ts/nni_manager/yarn.lock', 'nni_node/yarn.lock')
Path('nni_node/nni_manager.tsbuildinfo').unlink() Path('nni_node/nni_manager.tsbuildinfo').unlink()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment