"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3be0ff90562e4b1419d41fe9ec049e32d0ca3e4f"
Commit 09bc6113 authored by Konstantinos Bousmalis's avatar Konstantinos Bousmalis
Browse files

DSN Updates

parent 0dbc90d4
...@@ -4,17 +4,10 @@ ...@@ -4,17 +4,10 @@
## Introduction ## Introduction
This code is the code used for the "Domain Separation Networks" paper This code is the code used for the "Domain Separation Networks" paper
by Bousmalis K., Trigeorgis G., et al. which was presented at NIPS 2016. The by Bousmalis K., Trigeorgis G., et al. which was presented at NIPS 2016. The
<<<<<<< HEAD
paper can be found here: https://arxiv.org/abs/1608.06019
## Contact
This code was open-sourced by Konstantinos Bousmalis (konstantinos@google.com, github:bousmalis)
=======
paper can be found here: https://arxiv.org/abs/1608.06019. paper can be found here: https://arxiv.org/abs/1608.06019.
## Contact ## Contact
This code was open-sourced by [Konstantinos Bousmalis](https://github.com/bousmalis) (konstantinos@google.com). This code was open-sourced by [Konstantinos Bousmalis](https://github.com/bousmalis) (konstantinos@google.com).
>>>>>>> d6bee2c713c6aed6522ab32c34b57412d0216d95
## Installation ## Installation
You will need to have the following installed on your machine before trying out the DSN code. You will need to have the following installed on your machine before trying out the DSN code.
...@@ -26,35 +19,27 @@ You will need to have the following installed on your machine before trying out ...@@ -26,35 +19,27 @@ You will need to have the following installed on your machine before trying out
Although we are making the code available, you are only able to use the MNIST Although we are making the code available, you are only able to use the MNIST
provider for now. We will soon provide a script to download and convert MNIST-M provider for now. We will soon provide a script to download and convert MNIST-M
as well. Check back here in a few weeks or wait for a relevant announcement from as well. Check back here in a few weeks or wait for a relevant announcement from
<<<<<<< HEAD
Twitter @bousmalis.
=======
[@bousmalis](https://twitter.com/bousmalis). [@bousmalis](https://twitter.com/bousmalis).
>>>>>>> d6bee2c713c6aed6522ab32c34b57412d0216d95
## Running the code for adapting MNIST to MNIST-M ## Running the code for adapting MNIST to MNIST-M
In order to run the MNIST to MNIST-M experiments with DANNs and/or DANNs with In order to run the MNIST to MNIST-M experiments with DANNs and/or DANNs with
domain separation (DSNs) you will need to set the directory you used to download domain separation (DSNs) you will need to set the directory you used to download
<<<<<<< HEAD
MNIST and MNIST-M:\
=======
MNIST and MNIST-M: MNIST and MNIST-M:
>>>>>>> d6bee2c713c6aed6522ab32c34b57412d0216d95
``` ```
$ export DSN_DATA_DIR=/your/dir $ export DSN_DATA_DIR=/your/dir
``` ```
Then you need to build the binaries with Bazel: Add models and models/slim to your `$PYTHONPATH`:
``` ```
$ bazel build -c opt domain_adaptation/domain_separation/... $ export PYTHONPATH=$PYTHONPATH:$PWD:$PWD/slim
``` ```
Add models and models/slim to your `$PYTHONPATH`: Then you need to build the binaries with Bazel:
``` ```
$ export PYTHONPATH=$PYTHONPATH:$PWD:$PWD/slim $ bazel build -c opt domain_adaptation/domain_separation/...
``` ```
You can then train with the following command: You can then train with the following command:
......
...@@ -14,22 +14,7 @@ ...@@ -14,22 +14,7 @@
# ============================================================================== # ==============================================================================
# pylint: disable=line-too-long # pylint: disable=line-too-long
r"""Evaluation for Domain Separation Networks (DSNs). """Evaluation for Domain Separation Networks (DSNs)."""
To build locally for CPU:
blaze build -c opt --copt=-mavx \
third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
To build locally for GPU:
blaze build -c opt --copt=-mavx --config=cuda_clang \
third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
To run locally:
$
./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval
\
--alsologtostderr
"""
# pylint: enable=line-too-long # pylint: enable=line-too-long
import math import math
......
...@@ -13,30 +13,7 @@ ...@@ -13,30 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
# pylint: disable=line-too-long """Training for Domain Separation Networks (DSNs)."""
r"""Training for Domain Separation Networks (DSNs).
-- Compile:
$ blaze build -c opt --copt=-mavx --config=cuda \
third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_train
-- Run:
$
./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_train
\
--similarity_loss=dann \
--basic_tower=dsn_cropped_linemod \
--source_dataset=pose_synthetic \
--target_dataset=pose_real \
--learning_rate=0.012 \
--alpha_weight=0.26 \
--gamma_weight=0.0115 \
--weight_decay=4e-5 \
--layers_to_regularize=fc3 \
--use_separation \
--alsologtostderr
"""
# pylint: enable=line-too-long
from __future__ import division from __future__ import division
import tensorflow as tf import tensorflow as tf
...@@ -59,7 +36,7 @@ tf.app.flags.DEFINE_string('target_dataset', 'pose_real', ...@@ -59,7 +36,7 @@ tf.app.flags.DEFINE_string('target_dataset', 'pose_real',
tf.app.flags.DEFINE_string('target_labeled_dataset', 'none', tf.app.flags.DEFINE_string('target_labeled_dataset', 'none',
'Target dataset to train on.') 'Target dataset to train on.')
tf.app.flags.DEFINE_string('dataset_dir', '/cns/ok-d/home/konstantinos/cad_learning/', tf.app.flags.DEFINE_string('dataset_dir', None,
'The directory where the dataset files are stored.') 'The directory where the dataset files are stored.')
tf.app.flags.DEFINE_string('master', '', tf.app.flags.DEFINE_string('master', '',
......
...@@ -178,16 +178,14 @@ def dann_loss(source_samples, target_samples, weight, scope=None): ...@@ -178,16 +178,14 @@ def dann_loss(source_samples, target_samples, weight, scope=None):
assert_op = tf.Assert(tf.is_finite(domain_loss), [domain_loss]) assert_op = tf.Assert(tf.is_finite(domain_loss), [domain_loss])
with tf.control_dependencies([assert_op]): with tf.control_dependencies([assert_op]):
tag_loss = 'losses/Domain Loss' tag_loss = 'losses/domain_loss'
tag_accuracy = 'losses/Domain Accuracy' tag_accuracy = 'losses/domain_accuracy'
if scope: if scope:
tag_loss = scope + tag_loss tag_loss = scope + tag_loss
tag_accuracy = scope + tag_accuracy tag_accuracy = scope + tag_accuracy
tf.summary.scalar( tf.summary.scalar(tag_loss, domain_loss)
tag_loss, domain_loss, name='domain_loss_summary') tf.summary.scalar(tag_accuracy, domain_accuracy)
tf.summary.scalar(
tag_accuracy, domain_accuracy, name='domain_accuracy_summary')
return domain_loss return domain_loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment