Unverified commit 67287997, authored by SparkSnail and committed by GitHub
Browse files

Merge pull request #241 from microsoft/master

merge master
parents b4773e1e f8d42a33
......@@ -212,7 +212,7 @@ class AGP_Pruner(Pruner):
if epoch > 0:
self.now_epoch = epoch
for wrapper in self.get_modules_wrapper():
wrapper.if_calculated.copy_(torch.tensor(0)) # pylint: disable=not-callable
wrapper.if_calculated = False
class SlimPruner(Pruner):
"""
......@@ -329,10 +329,6 @@ class LotteryTicketPruner(Pruner):
reset_weights : bool
Whether reset weights and optimizer at the beginning of each round.
"""
super().__init__(model, config_list, optimizer)
self.curr_prune_iteration = None
self.prune_iterations = config_list[0]['prune_iterations']
# save init weights and optimizer
self.reset_weights = reset_weights
if self.reset_weights:
......@@ -344,6 +340,10 @@ class LotteryTicketPruner(Pruner):
if lr_scheduler is not None:
self._scheduler_state = copy.deepcopy(lr_scheduler.state_dict())
super().__init__(model, config_list, optimizer)
self.curr_prune_iteration = None
self.prune_iterations = config_list[0]['prune_iterations']
def validate_config(self, model, config_list):
"""
Parameters
......
......@@ -129,6 +129,7 @@ class DartsTrainer(Trainer):
self.mutator.reset()
logits = self.model(X)
loss = self.loss(logits, y)
self._write_graph_status()
return logits, loss
def _backward(self, val_X, val_y):
......
......@@ -126,6 +126,7 @@ class EnasTrainer(Trainer):
with torch.no_grad():
self.mutator.reset()
self._write_graph_status()
logits = self.model(x)
if isinstance(logits, tuple):
......@@ -159,6 +160,7 @@ class EnasTrainer(Trainer):
self.mutator.reset()
with torch.no_grad():
logits = self.model(x)
self._write_graph_status()
metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y)
if self.entropy_weight:
......
......@@ -3,6 +3,8 @@
import json
import logging
import os
import time
from abc import abstractmethod
import torch
......@@ -90,6 +92,9 @@ class Trainer(BaseTrainer):
self.batch_size = batch_size
self.workers = workers
self.log_frequency = log_frequency
self.log_dir = os.path.join("logs", str(time.time()))
os.makedirs(self.log_dir, exist_ok=True)
self.status_writer = open(os.path.join(self.log_dir, "log"), "w")
self.callbacks = callbacks if callbacks is not None else []
for callback in self.callbacks:
callback.build(self.model, self.mutator, self)
......@@ -168,3 +173,22 @@ class Trainer(BaseTrainer):
Return trainer checkpoint.
"""
raise NotImplementedError("Not implemented yet")
def enable_visualization(self):
    """
    Enable visualization. Write graph and training log to folder ``logs/<timestamp>``.

    Pulls one small sample batch (first two items) from ``self.train_loader`` so the
    mutator can trace the model graph, dumps it as ``graph.json`` under ``self.log_dir``,
    and flips ``self.visualization_enabled`` so ``_write_graph_status`` starts logging.
    """
    sample = None
    for x, _ in self.train_loader:
        # Only a tiny slice is needed to trace the graph; move it to the model's device.
        sample = x.to(self.device)[:2]
        break
    if sample is None:
        # Fix: previously we warned but still called self.mutator.graph(None),
        # which cannot produce a valid graph. Bail out without enabling.
        _logger.warning("Sample is %s.", sample)
        return
    _logger.info("Creating graph json, writing to %s. Visualization enabled.", self.log_dir)
    with open(os.path.join(self.log_dir, "graph.json"), "w") as f:
        json.dump(self.mutator.graph(sample), f)
    self.visualization_enabled = True
def _write_graph_status(self):
if hasattr(self, "visualization_enabled") and self.visualization_enabled:
print(json.dumps(self.mutator.status()), file=self.status_writer, flush=True)
......@@ -92,5 +92,8 @@
"presets": [
"react-app"
]
},
"resolutions": {
"npm": ">=6.14.4"
}
}
......@@ -17,8 +17,9 @@ interface AppState {
}
class App extends React.Component<{}, AppState> {
private timerId!: number | null;
private timerId!: number | undefined;
private dataFormatimer!: number;
private firstLoad: boolean = false; // when click refresh selector options
constructor(props: {}) {
super(props);
......@@ -66,14 +67,20 @@ class App extends React.Component<{}, AppState> {
}
}
}
changeInterval = (interval: number): void => {
this.setState({ interval });
if (this.timerId === null && interval !== 0) {
window.setTimeout(this.refresh);
} else if (this.timerId !== null && interval === 0) {
window.clearTimeout(this.timerId);
window.clearTimeout(this.timerId);
if (interval === 0) {
return;
}
// setState will trigger page refresh at once.
// setState is asyc, interval not update to (this.state.interval) at once.
this.setState({interval}, () => {
this.firstLoad = true;
this.refresh();
});
}
// TODO: use local storage
......@@ -123,24 +130,30 @@ class App extends React.Component<{}, AppState> {
}
private refresh = async (): Promise<void> => {
const [experimentUpdated, trialsUpdated] = await Promise.all([EXPERIMENT.update(), TRIALS.update()]);
if (experimentUpdated) {
this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 }));
}
if (trialsUpdated) {
this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 }));
// resolve this question: 10s -> 20s, page refresh twice.
// only refresh this page after clicking the refresh options
if (this.firstLoad !== true) {
const [experimentUpdated, trialsUpdated] = await Promise.all([EXPERIMENT.update(), TRIALS.update()]);
if (experimentUpdated) {
this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 }));
}
if (trialsUpdated) {
this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 }));
}
} else {
this.firstLoad = false;
}
if (['DONE', 'ERROR', 'STOPPED'].includes(EXPERIMENT.status)) {
// experiment finished, refresh once more to ensure consistency
if (this.state.interval > 0) {
this.setState({ interval: 0 });
this.lastRefresh();
}
} else if (this.state.interval !== 0) {
this.timerId = window.setTimeout(this.refresh, this.state.interval * 1000);
this.setState({ interval: 0 });
this.lastRefresh();
return;
}
this.timerId = window.setTimeout(this.refresh, this.state.interval * 1000);
}
public async lastRefresh(): Promise<void> {
......
This diff is collapsed.
"""Naive NNI trial: fetches parameters, reports dummy metrics, and exits."""
import time

import nni

if __name__ == '__main__':
    print('trial start')
    params = nni.get_next_parameter()
    print('params:', params)
    epochs = 2
    for epoch in range(epochs):
        # Simulate one training epoch: report a steadily improving metric.
        nni.report_intermediate_result(0.1 * (epoch + 1))
        time.sleep(1)
    nni.report_final_result(0.8)
    print('trial done')
......@@ -70,12 +70,18 @@ testCases:
config:
maxTrialNum: 2
trialConcurrency: 2
trial:
codeDir: ../naive_trial
command: python3 naive_trial.py
- name: assessor-medianstop
configFile: test/config/assessors/medianstop.yml
config:
maxTrialNum: 2
trialConcurrency: 2
trial:
codeDir: ../naive_trial
command: python3 naive_trial.py
#########################################################################
# nni tuners test
......@@ -89,7 +95,7 @@ testCases:
searchSpacePath: ../naive_trial/search_space.json
trial:
codeDir: ../naive_trial
command: python3 trial.py
command: python3 naive_trial.py
- name: tuner-evolution
configFile: test/config/tuners/evolution.yml
......@@ -100,7 +106,7 @@ testCases:
searchSpacePath: ../naive_trial/search_space.json
trial:
codeDir: ../naive_trial
command: python3 trial.py
command: python3 naive_trial.py
- name: tuner-random
configFile: test/config/tuners/random.yml
......@@ -111,7 +117,7 @@ testCases:
searchSpacePath: ../naive_trial/search_space.json
trial:
codeDir: ../naive_trial
command: python3 trial.py
command: python3 naive_trial.py
- name: tuner-tpe
configFile: test/config/tuners/tpe.yml
......@@ -122,7 +128,7 @@ testCases:
searchSpacePath: ../naive_trial/search_space.json
trial:
codeDir: ../naive_trial
command: python3 trial.py
command: python3 naive_trial.py
- name: tuner-batch
configFile: test/config/tuners/batch.yml
......@@ -144,7 +150,7 @@ testCases:
searchSpacePath: ../naive_trial/search_space.json
trial:
codeDir: ../naive_trial
command: python3 trial.py
command: python3 naive_trial.py
- name: tuner-grid
configFile: test/config/tuners/gridsearch.yml
......
......@@ -10,7 +10,7 @@ jobs:
python -m pip install scikit-learn==0.20.0 --user
python -m pip install keras==2.1.6 --user
python -m pip install torchvision===0.4.1 torch===1.3.1 -f https://download.pytorch.org/whl/torch_stable.html --user
python -m pip install tensorflow-gpu==1.11.0 --user
python -m pip install tensorflow-gpu==1.15.2 --user
displayName: 'Install dependencies for integration tests'
- script: |
cd test
......
......@@ -63,6 +63,7 @@ jobs:
cd test
set PATH=$(ENV_PATH)
python --version
python nni_test/nnitest/generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) --nni_docker_image $(docker_image) --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
mount -o anon $(pai_nfs_uri) $(local_nfs_uri)
python nni_test/nnitest/generate_ts_config.py --ts pai --pai_token $(pai_token) --pai_host $(pai_host) --pai_user $(pai_user) --nni_docker_image $(docker_image) --pai_storage_plugin $(pai_storage_plugin) --nni_manager_nfs_mount_path $(nni_manager_nfs_mount_path) --container_nfs_mount_path $(container_nfs_mount_path) --nni_manager_ip $(nni_manager_ip)
python nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts pai --exclude multi-phase
displayName: 'Examples and advanced features tests on pai'
\ No newline at end of file
......@@ -52,9 +52,3 @@ jobs:
runOptions: commands
commands: python3 /tmp/nnitest/$(Build.BuildId)/nni-remote/test/nni_test/nnitest/remote_docker.py --mode stop --name $(Build.BuildId) --os windows
displayName: 'Stop docker'
- task: SSH@0
inputs:
sshEndpoint: $(end_point)
runOptions: commands
commands: sudo rm -rf /tmp/nnitest/$(Build.BuildId)
displayName: 'Clean the remote files'
......@@ -19,6 +19,15 @@ do
python3 model_prune_torch.py --pruner_name $name --pretrain_epochs 1 --prune_epochs 1
done
echo 'testing level pruner pruning'
python3 model_prune_torch.py --pruner_name level --pretrain_epochs 1 --prune_epochs 1
echo 'testing agp pruning'
python3 model_prune_torch.py --pruner_name agp --pretrain_epochs 1 --prune_epochs 2
echo 'testing mean_activation pruning'
python3 model_prune_torch.py --pruner_name mean_activation --pretrain_epochs 1 --prune_epochs 1
#echo "testing lottery ticket pruning..."
#python3 lottery_torch_mnist_fc.py
......
......@@ -28,9 +28,10 @@ cd $EXAMPLE_DIR/enas
python3 search.py --search-for macro --epochs 1
python3 search.py --search-for micro --epochs 1
echo "testing naive..."
cd $EXAMPLE_DIR/naive
python3 train.py
#disabled for now
#echo "testing naive..."
#cd $EXAMPLE_DIR/naive
#python3 train.py
echo "testing pdarts..."
cd $EXAMPLE_DIR/pdarts
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment