Unverified Commit 67287997 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #241 from microsoft/master

merge master
parents b4773e1e f8d42a33
...@@ -212,7 +212,7 @@ class AGP_Pruner(Pruner): ...@@ -212,7 +212,7 @@ class AGP_Pruner(Pruner):
if epoch > 0: if epoch > 0:
self.now_epoch = epoch self.now_epoch = epoch
for wrapper in self.get_modules_wrapper(): for wrapper in self.get_modules_wrapper():
wrapper.if_calculated.copy_(torch.tensor(0)) # pylint: disable=not-callable wrapper.if_calculated = False
class SlimPruner(Pruner): class SlimPruner(Pruner):
""" """
...@@ -329,10 +329,6 @@ class LotteryTicketPruner(Pruner): ...@@ -329,10 +329,6 @@ class LotteryTicketPruner(Pruner):
reset_weights : bool reset_weights : bool
Whether reset weights and optimizer at the beginning of each round. Whether reset weights and optimizer at the beginning of each round.
""" """
super().__init__(model, config_list, optimizer)
self.curr_prune_iteration = None
self.prune_iterations = config_list[0]['prune_iterations']
# save init weights and optimizer # save init weights and optimizer
self.reset_weights = reset_weights self.reset_weights = reset_weights
if self.reset_weights: if self.reset_weights:
...@@ -344,6 +340,10 @@ class LotteryTicketPruner(Pruner): ...@@ -344,6 +340,10 @@ class LotteryTicketPruner(Pruner):
if lr_scheduler is not None: if lr_scheduler is not None:
self._scheduler_state = copy.deepcopy(lr_scheduler.state_dict()) self._scheduler_state = copy.deepcopy(lr_scheduler.state_dict())
super().__init__(model, config_list, optimizer)
self.curr_prune_iteration = None
self.prune_iterations = config_list[0]['prune_iterations']
def validate_config(self, model, config_list): def validate_config(self, model, config_list):
""" """
Parameters Parameters
......
...@@ -129,6 +129,7 @@ class DartsTrainer(Trainer): ...@@ -129,6 +129,7 @@ class DartsTrainer(Trainer):
self.mutator.reset() self.mutator.reset()
logits = self.model(X) logits = self.model(X)
loss = self.loss(logits, y) loss = self.loss(logits, y)
self._write_graph_status()
return logits, loss return logits, loss
def _backward(self, val_X, val_y): def _backward(self, val_X, val_y):
......
...@@ -126,6 +126,7 @@ class EnasTrainer(Trainer): ...@@ -126,6 +126,7 @@ class EnasTrainer(Trainer):
with torch.no_grad(): with torch.no_grad():
self.mutator.reset() self.mutator.reset()
self._write_graph_status()
logits = self.model(x) logits = self.model(x)
if isinstance(logits, tuple): if isinstance(logits, tuple):
...@@ -159,6 +160,7 @@ class EnasTrainer(Trainer): ...@@ -159,6 +160,7 @@ class EnasTrainer(Trainer):
self.mutator.reset() self.mutator.reset()
with torch.no_grad(): with torch.no_grad():
logits = self.model(x) logits = self.model(x)
self._write_graph_status()
metrics = self.metrics(logits, y) metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y) reward = self.reward_function(logits, y)
if self.entropy_weight: if self.entropy_weight:
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
import json import json
import logging import logging
import os
import time
from abc import abstractmethod from abc import abstractmethod
import torch import torch
...@@ -90,6 +92,9 @@ class Trainer(BaseTrainer): ...@@ -90,6 +92,9 @@ class Trainer(BaseTrainer):
self.batch_size = batch_size self.batch_size = batch_size
self.workers = workers self.workers = workers
self.log_frequency = log_frequency self.log_frequency = log_frequency
self.log_dir = os.path.join("logs", str(time.time()))
os.makedirs(self.log_dir, exist_ok=True)
self.status_writer = open(os.path.join(self.log_dir, "log"), "w")
self.callbacks = callbacks if callbacks is not None else [] self.callbacks = callbacks if callbacks is not None else []
for callback in self.callbacks: for callback in self.callbacks:
callback.build(self.model, self.mutator, self) callback.build(self.model, self.mutator, self)
...@@ -168,3 +173,22 @@ class Trainer(BaseTrainer): ...@@ -168,3 +173,22 @@ class Trainer(BaseTrainer):
Return trainer checkpoint. Return trainer checkpoint.
""" """
raise NotImplementedError("Not implemented yet") raise NotImplementedError("Not implemented yet")
def enable_visualization(self):
    """
    Turn on visualization output for this trainer.

    Pulls the first batch from ``self.train_loader``, moves its inputs to
    ``self.device`` and keeps the leading ``[:2]`` slice as a sample. The result
    of ``self.mutator.graph(sample)`` is dumped as JSON to ``graph.json`` inside
    ``self.log_dir``, and ``self.visualization_enabled`` is set to ``True``.

    When the loader yields no batch, a warning is logged and ``None`` is passed
    to ``self.mutator.graph`` as the sample.
    """
    # A private sentinel distinguishes "loader yielded nothing" from any real batch.
    nothing = object()
    first_batch = next(iter(self.train_loader), nothing)
    if first_batch is nothing:
        sample = None
        _logger.warning("Sample is %s.", sample)
    else:
        inputs, _ = first_batch
        sample = inputs.to(self.device)[:2]

    _logger.info("Creating graph json, writing to %s. Visualization enabled.", self.log_dir)
    graph_path = os.path.join(self.log_dir, "graph.json")
    with open(graph_path, "w") as graph_file:
        json.dump(self.mutator.graph(sample), graph_file)
    self.visualization_enabled = True
def _write_graph_status(self):
    """
    Write the mutator's current status as one JSON line to ``self.status_writer``.

    Does nothing unless ``self.visualization_enabled`` exists and is truthy.
    The stream is flushed after every line.
    """
    if not getattr(self, "visualization_enabled", False):
        return
    status_line = json.dumps(self.mutator.status())
    print(status_line, file=self.status_writer, flush=True)
...@@ -92,5 +92,8 @@ ...@@ -92,5 +92,8 @@
"presets": [ "presets": [
"react-app" "react-app"
] ]
},
"resolutions": {
"npm": ">=6.14.4"
} }
} }
...@@ -17,8 +17,9 @@ interface AppState { ...@@ -17,8 +17,9 @@ interface AppState {
} }
class App extends React.Component<{}, AppState> { class App extends React.Component<{}, AppState> {
private timerId!: number | null; private timerId!: number | undefined;
private dataFormatimer!: number; private dataFormatimer!: number;
private firstLoad: boolean = false; // when click refresh selector options
constructor(props: {}) { constructor(props: {}) {
super(props); super(props);
...@@ -66,14 +67,20 @@ class App extends React.Component<{}, AppState> { ...@@ -66,14 +67,20 @@ class App extends React.Component<{}, AppState> {
} }
} }
} }
changeInterval = (interval: number): void => { changeInterval = (interval: number): void => {
this.setState({ interval });
if (this.timerId === null && interval !== 0) { window.clearTimeout(this.timerId);
window.setTimeout(this.refresh); if (interval === 0) {
} else if (this.timerId !== null && interval === 0) { return;
window.clearTimeout(this.timerId);
} }
// setState will trigger page refresh at once.
// setState is async, so this.state.interval is not updated at once.
this.setState({interval}, () => {
this.firstLoad = true;
this.refresh();
});
} }
// TODO: use local storage // TODO: use local storage
...@@ -123,24 +130,30 @@ class App extends React.Component<{}, AppState> { ...@@ -123,24 +130,30 @@ class App extends React.Component<{}, AppState> {
} }
private refresh = async (): Promise<void> => { private refresh = async (): Promise<void> => {
const [experimentUpdated, trialsUpdated] = await Promise.all([EXPERIMENT.update(), TRIALS.update()]);
if (experimentUpdated) { // resolve this question: 10s -> 20s, page refresh twice.
this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 })); // only refresh this page after clicking the refresh options
} if (this.firstLoad !== true) {
if (trialsUpdated) { const [experimentUpdated, trialsUpdated] = await Promise.all([EXPERIMENT.update(), TRIALS.update()]);
this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 })); if (experimentUpdated) {
this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 }));
}
if (trialsUpdated) {
this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 }));
}
} else {
this.firstLoad = false;
} }
if (['DONE', 'ERROR', 'STOPPED'].includes(EXPERIMENT.status)) { if (['DONE', 'ERROR', 'STOPPED'].includes(EXPERIMENT.status)) {
// experiment finished, refresh once more to ensure consistency // experiment finished, refresh once more to ensure consistency
if (this.state.interval > 0) { this.setState({ interval: 0 });
this.setState({ interval: 0 }); this.lastRefresh();
this.lastRefresh(); return;
}
} else if (this.state.interval !== 0) {
this.timerId = window.setTimeout(this.refresh, this.state.interval * 1000);
} }
this.timerId = window.setTimeout(this.refresh, this.state.interval * 1000);
} }
public async lastRefresh(): Promise<void> { public async lastRefresh(): Promise<void> {
......
This diff is collapsed.
import time
import nni
def main():
    """
    Run a minimal NNI trial used by the integration tests.

    Requests the next parameter set, reports two intermediate results
    (``0.1 * 1`` and ``0.1 * 2``) one second apart, then reports a fixed final
    result of 0.8. Progress is printed to stdout.
    """
    print('trial start')

    params = nni.get_next_parameter()
    print('params:', params)

    epochs = 2
    for step in range(1, epochs + 1):
        nni.report_intermediate_result(0.1 * step)
        time.sleep(1)

    nni.report_final_result(0.8)
    print('trial done')


if __name__ == '__main__':
    main()
...@@ -70,12 +70,18 @@ testCases: ...@@ -70,12 +70,18 @@ testCases:
config: config:
maxTrialNum: 2 maxTrialNum: 2
trialConcurrency: 2 trialConcurrency: 2
trial:
codeDir: ../naive_trial
command: python3 naive_trial.py
- name: assessor-medianstop - name: assessor-medianstop
configFile: test/config/assessors/medianstop.yml configFile: test/config/assessors/medianstop.yml
config: config:
maxTrialNum: 2 maxTrialNum: 2
trialConcurrency: 2 trialConcurrency: 2
trial:
codeDir: ../naive_trial
command: python3 naive_trial.py
######################################################################### #########################################################################
# nni tuners test # nni tuners test
...@@ -89,7 +95,7 @@ testCases: ...@@ -89,7 +95,7 @@ testCases:
searchSpacePath: ../naive_trial/search_space.json searchSpacePath: ../naive_trial/search_space.json
trial: trial:
codeDir: ../naive_trial codeDir: ../naive_trial
command: python3 trial.py command: python3 naive_trial.py
- name: tuner-evolution - name: tuner-evolution
configFile: test/config/tuners/evolution.yml configFile: test/config/tuners/evolution.yml
...@@ -100,7 +106,7 @@ testCases: ...@@ -100,7 +106,7 @@ testCases:
searchSpacePath: ../naive_trial/search_space.json searchSpacePath: ../naive_trial/search_space.json
trial: trial:
codeDir: ../naive_trial codeDir: ../naive_trial
command: python3 trial.py command: python3 naive_trial.py
- name: tuner-random - name: tuner-random
configFile: test/config/tuners/random.yml configFile: test/config/tuners/random.yml
...@@ -111,7 +117,7 @@ testCases: ...@@ -111,7 +117,7 @@ testCases:
searchSpacePath: ../naive_trial/search_space.json searchSpacePath: ../naive_trial/search_space.json
trial: trial:
codeDir: ../naive_trial codeDir: ../naive_trial
command: python3 trial.py command: python3 naive_trial.py
- name: tuner-tpe - name: tuner-tpe
configFile: test/config/tuners/tpe.yml configFile: test/config/tuners/tpe.yml
...@@ -122,7 +128,7 @@ testCases: ...@@ -122,7 +128,7 @@ testCases:
searchSpacePath: ../naive_trial/search_space.json searchSpacePath: ../naive_trial/search_space.json
trial: trial:
codeDir: ../naive_trial codeDir: ../naive_trial
command: python3 trial.py command: python3 naive_trial.py
- name: tuner-batch - name: tuner-batch
configFile: test/config/tuners/batch.yml configFile: test/config/tuners/batch.yml
...@@ -144,7 +150,7 @@ testCases: ...@@ -144,7 +150,7 @@ testCases:
searchSpacePath: ../naive_trial/search_space.json searchSpacePath: ../naive_trial/search_space.json
trial: trial:
codeDir: ../naive_trial codeDir: ../naive_trial
command: python3 trial.py command: python3 naive_trial.py
- name: tuner-grid - name: tuner-grid
configFile: test/config/tuners/gridsearch.yml configFile: test/config/tuners/gridsearch.yml
......
...@@ -10,7 +10,7 @@ jobs: ...@@ -10,7 +10,7 @@ jobs:
python -m pip install scikit-learn==0.20.0 --user python -m pip install scikit-learn==0.20.0 --user
python -m pip install keras==2.1.6 --user python -m pip install keras==2.1.6 --user
python -m pip install torchvision===0.4.1 torch===1.3.1 -f https://download.pytorch.org/whl/torch_stable.html --user python -m pip install torchvision===0.4.1 torch===1.3.1 -f https://download.pytorch.org/whl/torch_stable.html --user
python -m pip install tensorflow-gpu==1.11.0 --user python -m pip install tensorflow-gpu==1.15.2 --user
displayName: 'Install dependencies for integration tests' displayName: 'Install dependencies for integration tests'
- script: | - script: |
cd test cd test
......
...@@ -63,6 +63,7 @@ jobs: ...@@ -63,6 +63,7 @@ jobs:
cd test cd test
set PATH=$(ENV_PATH) set PATH=$(ENV_PATH)
python --version python --version
python nni_test/nnitest/generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) --nni_docker_image $(docker_image) --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip) mount -o anon $(pai_nfs_uri) $(local_nfs_uri)
python nni_test/nnitest/generate_ts_config.py --ts pai --pai_token $(pai_token) --pai_host $(pai_host) --pai_user $(pai_user) --nni_docker_image $(docker_image) --pai_storage_plugin $(pai_storage_plugin) --nni_manager_nfs_mount_path $(nni_manager_nfs_mount_path) --container_nfs_mount_path $(container_nfs_mount_path) --nni_manager_ip $(nni_manager_ip)
python nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts pai --exclude multi-phase python nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts pai --exclude multi-phase
displayName: 'Examples and advanced features tests on pai' displayName: 'Examples and advanced features tests on pai'
\ No newline at end of file
...@@ -52,9 +52,3 @@ jobs: ...@@ -52,9 +52,3 @@ jobs:
runOptions: commands runOptions: commands
commands: python3 /tmp/nnitest/$(Build.BuildId)/nni-remote/test/nni_test/nnitest/remote_docker.py --mode stop --name $(Build.BuildId) --os windows commands: python3 /tmp/nnitest/$(Build.BuildId)/nni-remote/test/nni_test/nnitest/remote_docker.py --mode stop --name $(Build.BuildId) --os windows
displayName: 'Stop docker' displayName: 'Stop docker'
- task: SSH@0
inputs:
sshEndpoint: $(end_point)
runOptions: commands
commands: sudo rm -rf /tmp/nnitest/$(Build.BuildId)
displayName: 'Clean the remote files'
...@@ -19,6 +19,15 @@ do ...@@ -19,6 +19,15 @@ do
python3 model_prune_torch.py --pruner_name $name --pretrain_epochs 1 --prune_epochs 1 python3 model_prune_torch.py --pruner_name $name --pretrain_epochs 1 --prune_epochs 1
done done
echo 'testing level pruner pruning'
python3 model_prune_torch.py --pruner_name level --pretrain_epochs 1 --prune_epochs 1
echo 'testing agp pruning'
python3 model_prune_torch.py --pruner_name agp --pretrain_epochs 1 --prune_epochs 2
echo 'testing mean_activation pruning'
python3 model_prune_torch.py --pruner_name mean_activation --pretrain_epochs 1 --prune_epochs 1
#echo "testing lottery ticket pruning..." #echo "testing lottery ticket pruning..."
#python3 lottery_torch_mnist_fc.py #python3 lottery_torch_mnist_fc.py
......
...@@ -28,9 +28,10 @@ cd $EXAMPLE_DIR/enas ...@@ -28,9 +28,10 @@ cd $EXAMPLE_DIR/enas
python3 search.py --search-for macro --epochs 1 python3 search.py --search-for macro --epochs 1
python3 search.py --search-for micro --epochs 1 python3 search.py --search-for micro --epochs 1
echo "testing naive..." #disabled for now
cd $EXAMPLE_DIR/naive #echo "testing naive..."
python3 train.py #cd $EXAMPLE_DIR/naive
#python3 train.py
echo "testing pdarts..." echo "testing pdarts..."
cd $EXAMPLE_DIR/pdarts cd $EXAMPLE_DIR/pdarts
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment