Unverified Commit bbda6a8a authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Replace tensorboardX with torch.utils.tensorboard (#2786)

parent f6991e8a
...@@ -28,7 +28,6 @@ jobs: ...@@ -28,7 +28,6 @@ jobs:
set -e set -e
sudo apt-get install -y pandoc sudo apt-get install -y pandoc
python3 -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user python3 -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==2.2.0 --user python3 -m pip install tensorflow==2.2.0 --user
python3 -m pip install keras==2.4.2 --user python3 -m pip install keras==2.4.2 --user
python3 -m pip install gym onnx peewee thop --user python3 -m pip install gym onnx peewee thop --user
...@@ -69,7 +68,6 @@ jobs: ...@@ -69,7 +68,6 @@ jobs:
- script: | - script: |
set -e set -e
python3 -m pip install torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html --user python3 -m pip install torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==1.15.2 --user python3 -m pip install tensorflow==1.15.2 --user
python3 -m pip install keras==2.1.6 --user python3 -m pip install keras==2.1.6 --user
python3 -m pip install gym onnx peewee --user python3 -m pip install gym onnx peewee --user
...@@ -119,7 +117,6 @@ jobs: ...@@ -119,7 +117,6 @@ jobs:
set -e set -e
# pytorch Mac binary does not support CUDA, default is cpu version # pytorch Mac binary does not support CUDA, default is cpu version
python3 -m pip install torchvision==0.6.0 torch==1.5.0 --user python3 -m pip install torchvision==0.6.0 torch==1.5.0 --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==1.15.2 --user python3 -m pip install tensorflow==1.15.2 --user
brew install swig@3 brew install swig@3
rm -f /usr/local/bin/swig rm -f /usr/local/bin/swig
...@@ -147,7 +144,6 @@ jobs: ...@@ -147,7 +144,6 @@ jobs:
python -m pip install scikit-learn==0.23.2 --user python -m pip install scikit-learn==0.23.2 --user
python -m pip install keras==2.1.6 --user python -m pip install keras==2.1.6 --user
python -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user python -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python -m pip install tensorboardX==1.9
python -m pip install tensorflow==1.15.2 --user python -m pip install tensorflow==1.15.2 --user
displayName: 'Install dependencies' displayName: 'Install dependencies'
- script: | - script: |
......
...@@ -6,6 +6,7 @@ from copy import deepcopy ...@@ -6,6 +6,7 @@ from copy import deepcopy
from argparse import Namespace from argparse import Namespace
import numpy as np import numpy as np
import torch import torch
from torch.utils.tensorboard import SummaryWriter
from nni.compression.torch.compressor import Pruner from nni.compression.torch.compressor import Pruner
from .channel_pruning_env import ChannelPruningEnv from .channel_pruning_env import ChannelPruningEnv
...@@ -148,8 +149,6 @@ class AMCPruner(Pruner): ...@@ -148,8 +149,6 @@ class AMCPruner(Pruner):
epsilon=50000, epsilon=50000,
seed=None): seed=None):
from tensorboardX import SummaryWriter
self.job = job self.job = job
self.searched_model_path = searched_model_path self.searched_model_path = searched_model_path
self.export_path = export_path self.export_path = export_path
...@@ -189,7 +188,7 @@ class AMCPruner(Pruner): ...@@ -189,7 +188,7 @@ class AMCPruner(Pruner):
if self.job == 'train_export': if self.job == 'train_export':
print('=> Saving logs to {}'.format(self.output_dir)) print('=> Saving logs to {}'.format(self.output_dir))
self.tfwriter = SummaryWriter(logdir=self.output_dir) self.tfwriter = SummaryWriter(log_dir=self.output_dir)
self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w') self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
print('=> Output path: {}...'.format(self.output_dir)) print('=> Output path: {}...'.format(self.output_dir))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment