Unverified Commit bbda6a8a authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Replace tensorboardX with torch.utils.tensorboard (#2786)

parent f6991e8a
......@@ -28,7 +28,6 @@ jobs:
set -e
sudo apt-get install -y pandoc
python3 -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==2.2.0 --user
python3 -m pip install keras==2.4.2 --user
python3 -m pip install gym onnx peewee thop --user
......@@ -69,7 +68,6 @@ jobs:
- script: |
set -e
python3 -m pip install torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==1.15.2 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install gym onnx peewee --user
......@@ -119,7 +117,6 @@ jobs:
set -e
# pytorch Mac binary does not support CUDA, default is cpu version
python3 -m pip install torchvision==0.6.0 torch==1.5.0 --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==1.15.2 --user
brew install swig@3
rm -f /usr/local/bin/swig
......@@ -147,7 +144,6 @@ jobs:
python -m pip install scikit-learn==0.23.2 --user
python -m pip install keras==2.1.6 --user
python -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python -m pip install tensorboardX==1.9
python -m pip install tensorflow==1.15.2 --user
displayName: 'Install dependencies'
- script: |
......
......@@ -6,6 +6,7 @@ from copy import deepcopy
from argparse import Namespace
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from nni.compression.torch.compressor import Pruner
from .channel_pruning_env import ChannelPruningEnv
......@@ -148,8 +149,6 @@ class AMCPruner(Pruner):
epsilon=50000,
seed=None):
from tensorboardX import SummaryWriter
self.job = job
self.searched_model_path = searched_model_path
self.export_path = export_path
......@@ -189,7 +188,7 @@ class AMCPruner(Pruner):
if self.job == 'train_export':
print('=> Saving logs to {}'.format(self.output_dir))
self.tfwriter = SummaryWriter(logdir=self.output_dir)
self.tfwriter = SummaryWriter(log_dir=self.output_dir)
self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
print('=> Output path: {}...'.format(self.output_dir))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment