Unverified Commit 37b73e4e authored by J-shang's avatar J-shang Committed by GitHub
Browse files

fix pipeline (#4721)

parent 303862a3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .one_shot_pruner import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import tensorflow as tf
from nni.compression.tensorflow import Pruner
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .compressor import Compressor, LayerInfo
from .pruner import Pruner, PrunerModuleWrapper
from .scheduler import BasePruningScheduler, Task, TaskResult
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .basic_pruner import *
from .basic_scheduler import PruningScheduler
from .iterative_pruner import *
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base import (
HookCollectorInfo,
DataCollector,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .agent import DDPG
from .amc_env import AMCEnv
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .config_validation import CompressorSchema
from .pruning import (
config_list_canonical,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .gbdt_selector import GBDTSelector
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .gradient_selector import FeatureGradientSelector
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from enum import Enum, EnumMeta
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from torch.quantization import default_weight_observer, default_histogram_observer
from torch.quantization import RecordingObserver as _RecordingObserver
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Optional
from .literal import QuantDtype, QuantType, QuantScheme
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from nni.common.version import TORCH_VERSION
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .integrated_tensorrt import CalibrateType, ModelSpeedupTensorRT
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .compressor import ModelSpeedup
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .utils import *
from .shape_dependency import *
from .shape_dependency import ReshapeDependency
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
def get_total_num_weights(model, op_types=['default']):
'''
calculate the total number of weights
......
......@@ -120,7 +120,11 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
link = Path(config.experiment_working_directory, '_latest')
try:
link.unlink(missing_ok=True)
if sys.version_info >= (3, 8):
link.unlink(missing_ok=True)
else:
if link.exists():
link.unlink()
link.symlink_to(exp_id, target_is_directory=True)
except Exception:
if sys.platform != 'win32':
......
......@@ -223,7 +223,16 @@ def copy_nni_node(version):
"""
_print('Copying files')
shutil.copytree('ts/nni_manager/dist', 'nni_node', dirs_exist_ok=True)
if sys.version_info >= (3, 8):
shutil.copytree('ts/nni_manager/dist', 'nni_node', dirs_exist_ok=True)
else:
for item in os.listdir('ts/nni_manager/dist'):
subsrc = os.path.join('ts/nni_manager/dist', item)
subdst = os.path.join('nni_node', item)
if os.path.isdir(subsrc):
shutil.copytree(subsrc, subdst)
else:
shutil.copy2(subsrc, subdst)
shutil.copyfile('ts/nni_manager/yarn.lock', 'nni_node/yarn.lock')
Path('nni_node/nni_manager.tsbuildinfo').unlink()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment