Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bf748370
Commit
bf748370
authored
Aug 23, 2019
by
Nimit Nigania
Browse files
Merge remote-tracking branch 'upstream/master'
parents
7c732da7
0d2c2e01
Changes
92
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
533 additions
and
215 deletions
+533
-215
official/transformer/v2/transformer_benchmark.py
official/transformer/v2/transformer_benchmark.py
+27
-29
official/transformer/v2/transformer_main.py
official/transformer/v2/transformer_main.py
+234
-63
official/transformer/v2/transformer_main_test.py
official/transformer/v2/transformer_main_test.py
+20
-5
official/transformer/v2/translate.py
official/transformer/v2/translate.py
+63
-19
official/utils/flags/_distribution.py
official/utils/flags/_distribution.py
+54
-0
official/utils/flags/_performance.py
official/utils/flags/_performance.py
+2
-2
official/utils/flags/core.py
official/utils/flags/core.py
+3
-0
official/utils/flags/flags_test.py
official/utils/flags/flags_test.py
+3
-1
official/utils/misc/distribution_utils.py
official/utils/misc/distribution_utils.py
+56
-17
official/utils/testing/integration.py
official/utils/testing/integration.py
+1
-5
official/utils/testing/pylint.rcfile
official/utils/testing/pylint.rcfile
+1
-1
official/vision/image_classification/README.md
official/vision/image_classification/README.md
+33
-33
official/vision/image_classification/__init__.py
official/vision/image_classification/__init__.py
+0
-0
official/vision/image_classification/cifar_preprocessing.py
official/vision/image_classification/cifar_preprocessing.py
+1
-1
official/vision/image_classification/common.py
official/vision/image_classification/common.py
+3
-6
official/vision/image_classification/common_test.py
official/vision/image_classification/common_test.py
+7
-7
official/vision/image_classification/imagenet_preprocessing.py
...ial/vision/image_classification/imagenet_preprocessing.py
+0
-0
official/vision/image_classification/resnet_cifar_main.py
official/vision/image_classification/resnet_cifar_main.py
+11
-11
official/vision/image_classification/resnet_cifar_model.py
official/vision/image_classification/resnet_cifar_model.py
+0
-0
official/vision/image_classification/resnet_cifar_test.py
official/vision/image_classification/resnet_cifar_test.py
+14
-15
No files found.
official/transformer/v2/transformer_benchmark.py
View file @
bf748370
...
@@ -21,6 +21,7 @@ import os
...
@@ -21,6 +21,7 @@ import os
import
time
import
time
from
absl
import
flags
from
absl
import
flags
import
tensorflow
as
tf
from
official.transformer.v2
import
misc
from
official.transformer.v2
import
misc
from
official.transformer.v2
import
transformer_main
as
transformer_main
from
official.transformer.v2
import
transformer_main
as
transformer_main
...
@@ -30,6 +31,7 @@ from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
...
@@ -30,6 +31,7 @@ from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
TRANSFORMER_EN2DE_DATA_DIR_NAME
=
'wmt32k-en2de-official'
TRANSFORMER_EN2DE_DATA_DIR_NAME
=
'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME
=
'newstest2014'
EN2DE_2014_BLEU_DATA_DIR_NAME
=
'newstest2014'
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
TMP_DIR
=
os
.
getenv
(
'TMPDIR'
)
class
TransformerBenchmark
(
PerfZeroBenchmark
):
class
TransformerBenchmark
(
PerfZeroBenchmark
):
...
@@ -56,6 +58,11 @@ class TransformerBenchmark(PerfZeroBenchmark):
...
@@ -56,6 +58,11 @@ class TransformerBenchmark(PerfZeroBenchmark):
EN2DE_2014_BLEU_DATA_DIR_NAME
,
EN2DE_2014_BLEU_DATA_DIR_NAME
,
'newstest2014.de'
)
'newstest2014.de'
)
default_flags
[
'train_steps'
]
=
200
default_flags
[
'log_steps'
]
=
10
default_flags
[
'data_dir'
]
=
self
.
train_data_dir
default_flags
[
'vocab_file'
]
=
self
.
vocab_file
super
(
TransformerBenchmark
,
self
).
__init__
(
super
(
TransformerBenchmark
,
self
).
__init__
(
output_dir
=
output_dir
,
output_dir
=
output_dir
,
default_flags
=
default_flags
,
default_flags
=
default_flags
,
...
@@ -280,8 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
...
@@ -280,8 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_8_gpu'
)
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'benchmark_8_gpu'
)
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
log_steps
=
FLAGS
.
log_steps
,
log_steps
=
FLAGS
.
log_steps
,
bleu_min
=
2
8
,
bleu_min
=
2
7.9
,
bleu_max
=
29
)
bleu_max
=
29
.2
)
def
benchmark_8_gpu_static_batch
(
self
):
def
benchmark_8_gpu_static_batch
(
self
):
"""Benchmark 8 gpu.
"""Benchmark 8 gpu.
...
@@ -305,12 +312,19 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
...
@@ -305,12 +312,19 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
log_steps
=
FLAGS
.
log_steps
,
log_steps
=
FLAGS
.
log_steps
,
bleu_min
=
28
,
bleu_min
=
28
,
bleu_max
=
29
)
bleu_max
=
29
.2
)
def
benchmark_8_gpu_fp16
(
self
):
def
benchmark_8_gpu_fp16
(
self
):
"""Benchmark 8 gpu with dynamic batch and fp16.
"""Benchmark 8 gpu with dynamic batch and fp16.
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
Over 6 runs with eval every 20K steps the average highest value was 28.247
(bleu uncased). 28.424 was the highest and 28.09 the lowest. The values are
the highest value seen during a run and occurred at a median of iteration
11. While this could be interpreted as worse than FP32, if looking at the
first iteration at which 28 is passed FP16 performs equal and possibly
better. Although not part of the initial test runs, the highest value
recorded with the arguments below was 28.9 at iteration 12. Iterations are
not epochs, an iteration is a number of steps between evals.
"""
"""
self
.
_setup
()
self
.
_setup
()
FLAGS
.
num_gpus
=
8
FLAGS
.
num_gpus
=
8
...
@@ -328,7 +342,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
...
@@ -328,7 +342,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
log_steps
=
FLAGS
.
log_steps
,
log_steps
=
FLAGS
.
log_steps
,
bleu_min
=
28
,
bleu_min
=
28
,
bleu_max
=
29
)
bleu_max
=
29
.2
)
def
benchmark_8_gpu_static_batch_fp16
(
self
):
def
benchmark_8_gpu_static_batch_fp16
(
self
):
"""Benchmark 8 gpu with static batch and fp16.
"""Benchmark 8 gpu with static batch and fp16.
...
@@ -353,7 +367,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
...
@@ -353,7 +367,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
log_steps
=
FLAGS
.
log_steps
,
log_steps
=
FLAGS
.
log_steps
,
bleu_min
=
28
,
bleu_min
=
28
,
bleu_max
=
29
)
bleu_max
=
29
.2
)
def
benchmark_xla_8_gpu_static_batch_fp16
(
self
):
def
benchmark_xla_8_gpu_static_batch_fp16
(
self
):
"""Benchmark 8 gpu with static batch, XLA, and FP16.
"""Benchmark 8 gpu with static batch, XLA, and FP16.
...
@@ -380,7 +394,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
...
@@ -380,7 +394,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
self
.
_run_and_report_benchmark
(
total_batch_size
=
FLAGS
.
batch_size
,
log_steps
=
FLAGS
.
log_steps
,
log_steps
=
FLAGS
.
log_steps
,
bleu_min
=
28
,
bleu_min
=
28
,
bleu_max
=
29
)
bleu_max
=
29
.2
)
class
TransformerKerasBenchmark
(
TransformerBenchmark
):
class
TransformerKerasBenchmark
(
TransformerBenchmark
):
...
@@ -611,19 +625,9 @@ class TransformerKerasBenchmark(TransformerBenchmark):
...
@@ -611,19 +625,9 @@ class TransformerKerasBenchmark(TransformerBenchmark):
class
TransformerBaseKerasBenchmarkReal
(
TransformerKerasBenchmark
):
class
TransformerBaseKerasBenchmarkReal
(
TransformerKerasBenchmark
):
"""Transformer based version real data benchmark tests."""
"""Transformer based version real data benchmark tests."""
def
__init__
(
self
,
output_dir
=
None
,
root_data_dir
=
None
,
**
kwargs
):
def
__init__
(
self
,
output_dir
=
TMP_DIR
,
root_data_dir
=
None
,
**
kwargs
):
train_data_dir
=
os
.
path
.
join
(
root_data_dir
,
TRANSFORMER_EN2DE_DATA_DIR_NAME
)
vocab_file
=
os
.
path
.
join
(
root_data_dir
,
TRANSFORMER_EN2DE_DATA_DIR_NAME
,
'vocab.ende.32768'
)
def_flags
=
{}
def_flags
=
{}
def_flags
[
'param_set'
]
=
'base'
def_flags
[
'param_set'
]
=
'base'
def_flags
[
'vocab_file'
]
=
vocab_file
def_flags
[
'data_dir'
]
=
train_data_dir
def_flags
[
'train_steps'
]
=
200
def_flags
[
'log_steps'
]
=
10
super
(
TransformerBaseKerasBenchmarkReal
,
self
).
__init__
(
super
(
TransformerBaseKerasBenchmarkReal
,
self
).
__init__
(
output_dir
=
output_dir
,
default_flags
=
def_flags
,
output_dir
=
output_dir
,
default_flags
=
def_flags
,
...
@@ -633,20 +637,14 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
...
@@ -633,20 +637,14 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
class
TransformerBigKerasBenchmarkReal
(
TransformerKerasBenchmark
):
class
TransformerBigKerasBenchmarkReal
(
TransformerKerasBenchmark
):
"""Transformer based version real data benchmark tests."""
"""Transformer based version real data benchmark tests."""
def
__init__
(
self
,
output_dir
=
None
,
root_data_dir
=
None
,
**
kwargs
):
def
__init__
(
self
,
output_dir
=
TMP_DIR
,
root_data_dir
=
None
,
**
kwargs
):
train_data_dir
=
os
.
path
.
join
(
root_data_dir
,
TRANSFORMER_EN2DE_DATA_DIR_NAME
)
vocab_file
=
os
.
path
.
join
(
root_data_dir
,
TRANSFORMER_EN2DE_DATA_DIR_NAME
,
'vocab.ende.32768'
)
def_flags
=
{}
def_flags
=
{}
def_flags
[
'param_set'
]
=
'big'
def_flags
[
'param_set'
]
=
'big'
def_flags
[
'vocab_file'
]
=
vocab_file
def_flags
[
'data_dir'
]
=
train_data_dir
def_flags
[
'train_steps'
]
=
200
def_flags
[
'log_steps'
]
=
10
super
(
TransformerBigKerasBenchmarkReal
,
self
).
__init__
(
super
(
TransformerBigKerasBenchmarkReal
,
self
).
__init__
(
output_dir
=
output_dir
,
default_flags
=
def_flags
,
output_dir
=
output_dir
,
default_flags
=
def_flags
,
root_data_dir
=
root_data_dir
,
batch_per_gpu
=
3072
)
root_data_dir
=
root_data_dir
,
batch_per_gpu
=
3072
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/transformer/v2/transformer_main.py
View file @
bf748370
...
@@ -27,12 +27,16 @@ import tempfile
...
@@ -27,12 +27,16 @@ import tempfile
from
absl
import
app
as
absl_app
# pylint: disable=unused-import
from
absl
import
app
as
absl_app
# pylint: disable=unused-import
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.util
import
object_identity
# pylint: disable=g-bad-import-order
# pylint: disable=g-bad-import-order
from
official.transformer
import
compute_bleu
from
official.transformer
import
compute_bleu
from
official.transformer.utils
import
tokenizer
from
official.transformer.utils
import
tokenizer
from
official.transformer.v2
import
data_pipeline
from
official.transformer.v2
import
data_pipeline
from
official.transformer.v2
import
metrics
from
official.transformer.v2
import
misc
from
official.transformer.v2
import
misc
from
official.transformer.v2
import
optimizer
from
official.transformer.v2
import
optimizer
from
official.transformer.v2
import
transformer
from
official.transformer.v2
import
transformer
...
@@ -48,18 +52,40 @@ BLEU_DIR = "bleu"
...
@@ -48,18 +52,40 @@ BLEU_DIR = "bleu"
_SINGLE_SAMPLE
=
1
_SINGLE_SAMPLE
=
1
def
translate_and_compute_bleu
(
model
,
subtokenizer
,
bleu_source
,
bleu_ref
):
def
translate_and_compute_bleu
(
model
,
"""Translate file and report the cased and uncased bleu scores."""
params
,
subtokenizer
,
bleu_source
,
bleu_ref
,
distribution_strategy
=
None
):
"""Translate file and report the cased and uncased bleu scores.
Args:
model: A Keras model, used to generate the translations.
params: A dictionary, containing the translation related parameters.
subtokenizer: A subtokenizer object, used for encoding and decoding source
and translated lines.
bleu_source: A file containing source sentences for translation.
bleu_ref: A file containing the reference for the translated sentences.
distribution_strategy: A platform distribution strategy, used for TPU based
translation.
Returns:
uncased_score: A float, the case insensitive BLEU score.
cased_score: A float, the case sensitive BLEU score.
"""
# Create temporary file to store translation.
# Create temporary file to store translation.
tmp
=
tempfile
.
NamedTemporaryFile
(
delete
=
False
)
tmp
=
tempfile
.
NamedTemporaryFile
(
delete
=
False
)
tmp_filename
=
tmp
.
name
tmp_filename
=
tmp
.
name
translate
.
translate_file
(
translate
.
translate_file
(
model
,
model
,
params
,
subtokenizer
,
subtokenizer
,
bleu_source
,
bleu_source
,
output_file
=
tmp_filename
,
output_file
=
tmp_filename
,
print_all_translations
=
False
)
print_all_translations
=
False
,
distribution_strategy
=
distribution_strategy
)
# Compute uncased and cased bleu scores.
# Compute uncased and cased bleu scores.
uncased_score
=
compute_bleu
.
bleu_wrapper
(
bleu_ref
,
tmp_filename
,
False
)
uncased_score
=
compute_bleu
.
bleu_wrapper
(
bleu_ref
,
tmp_filename
,
False
)
...
@@ -68,15 +94,34 @@ def translate_and_compute_bleu(model, subtokenizer, bleu_source, bleu_ref):
...
@@ -68,15 +94,34 @@ def translate_and_compute_bleu(model, subtokenizer, bleu_source, bleu_ref):
return
uncased_score
,
cased_score
return
uncased_score
,
cased_score
def
evaluate_and_log_bleu
(
model
,
bleu_source
,
bleu_ref
,
vocab_file
):
def
evaluate_and_log_bleu
(
model
,
"""Calculate and record the BLEU score."""
params
,
bleu_source
,
bleu_ref
,
vocab_file
,
distribution_strategy
=
None
):
"""Calculate and record the BLEU score.
Args:
model: A Keras model, used to generate the translations.
params: A dictionary, containing the translation related parameters.
bleu_source: A file containing source sentences for translation.
bleu_ref: A file containing the reference for the translated sentences.
vocab_file: A file containing the vocabulary for translation.
distribution_strategy: A platform distribution strategy, used for TPU based
translation.
Returns:
uncased_score: A float, the case insensitive BLEU score.
cased_score: A float, the case sensitive BLEU score.
"""
subtokenizer
=
tokenizer
.
Subtokenizer
(
vocab_file
)
subtokenizer
=
tokenizer
.
Subtokenizer
(
vocab_file
)
uncased_score
,
cased_score
=
translate_and_compute_bleu
(
uncased_score
,
cased_score
=
translate_and_compute_bleu
(
model
,
subtokenizer
,
bleu_source
,
bleu_ref
)
model
,
params
,
subtokenizer
,
bleu_source
,
bleu_ref
,
distribution_strategy
)
tf
.
compat
.
v1
.
logging
.
info
(
"Bleu score (uncased): %s"
,
uncased_score
)
logging
.
info
(
"Bleu score (uncased): %s"
,
uncased_score
)
tf
.
compat
.
v1
.
logging
.
info
(
"Bleu score (cased): %s"
,
cased_score
)
logging
.
info
(
"Bleu score (cased): %s"
,
cased_score
)
return
uncased_score
,
cased_score
return
uncased_score
,
cased_score
...
@@ -88,30 +133,27 @@ class TransformerTask(object):
...
@@ -88,30 +133,27 @@ class TransformerTask(object):
Args:
Args:
flags_obj: Object containing parsed flag values, i.e., FLAGS.
flags_obj: Object containing parsed flag values, i.e., FLAGS.
Raises:
ValueError: if not using static batch for input data on TPU.
"""
"""
self
.
flags_obj
=
flags_obj
self
.
flags_obj
=
flags_obj
self
.
predict_model
=
None
self
.
predict_model
=
None
# Add flag-defined parameters to params object
# Add flag-defined parameters to params object
num_gpus
=
flags_core
.
get_num_gpus
(
flags_obj
)
num_gpus
=
flags_core
.
get_num_gpus
(
flags_obj
)
self
.
distribution_strategy
=
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
flags_obj
.
distribution_strategy
,
num_gpus
=
flags_core
.
get_num_gpus
(
flags_obj
))
print
(
"Running transformer with num_gpus ="
,
num_gpus
)
if
self
.
distribution_strategy
:
print
(
"For training, using distribution strategy: "
,
self
.
distribution_strategy
)
else
:
print
(
"Not using any distribution strategy."
)
self
.
params
=
params
=
misc
.
get_model_params
(
flags_obj
.
param_set
,
num_gpus
)
self
.
params
=
params
=
misc
.
get_model_params
(
flags_obj
.
param_set
,
num_gpus
)
params
[
"num_gpus"
]
=
num_gpus
params
[
"num_gpus"
]
=
num_gpus
params
[
"use_ctl"
]
=
flags_obj
.
use_ctl
params
[
"is_tpu_pod"
]
=
flags_obj
.
is_tpu_pod
params
[
"data_dir"
]
=
flags_obj
.
data_dir
params
[
"data_dir"
]
=
flags_obj
.
data_dir
params
[
"model_dir"
]
=
flags_obj
.
model_dir
params
[
"model_dir"
]
=
flags_obj
.
model_dir
params
[
"static_batch"
]
=
flags_obj
.
static_batch
params
[
"static_batch"
]
=
flags_obj
.
static_batch
params
[
"max_length"
]
=
flags_obj
.
max_length
params
[
"max_length"
]
=
flags_obj
.
max_length
params
[
"decode_batch_size"
]
=
flags_obj
.
decode_batch_size
params
[
"decode_max_length"
]
=
flags_obj
.
decode_max_length
params
[
"padded_decode"
]
=
flags_obj
.
padded_decode
params
[
"num_parallel_calls"
]
=
(
params
[
"num_parallel_calls"
]
=
(
flags_obj
.
num_parallel_calls
or
tf
.
data
.
experimental
.
AUTOTUNE
)
flags_obj
.
num_parallel_calls
or
tf
.
data
.
experimental
.
AUTOTUNE
)
...
@@ -130,33 +172,114 @@ class TransformerTask(object):
...
@@ -130,33 +172,114 @@ class TransformerTask(object):
"infer_float32_vars"
)
"infer_float32_vars"
)
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
policy
)
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
policy
)
self
.
distribution_strategy
=
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
flags_obj
.
distribution_strategy
,
num_gpus
=
num_gpus
,
tpu_address
=
flags_obj
.
tpu
or
""
)
if
self
.
use_tpu
:
params
[
"num_replicas"
]
=
self
.
distribution_strategy
.
num_replicas_in_sync
if
not
params
[
"static_batch"
]:
raise
ValueError
(
"TPU requires static batch for input data."
)
else
:
print
(
"Running transformer with num_gpus ="
,
num_gpus
)
if
self
.
distribution_strategy
:
print
(
"For training, using distribution strategy: "
,
self
.
distribution_strategy
)
else
:
print
(
"Not using any distribution strategy."
)
@
property
def
use_tpu
(
self
):
if
self
.
distribution_strategy
:
return
isinstance
(
self
.
distribution_strategy
,
tf
.
distribute
.
experimental
.
TPUStrategy
)
return
False
def
train
(
self
):
def
train
(
self
):
"""Trains the model."""
"""Trains the model."""
params
,
flags_obj
,
is_train
=
self
.
params
,
self
.
flags_obj
,
True
params
=
self
.
params
flags_obj
=
self
.
flags_obj
# Sets config options.
# Sets config options.
keras_utils
.
set_session_config
(
keras_utils
.
set_session_config
(
enable_xla
=
flags_obj
.
enable_xla
)
enable_xla
=
flags_obj
.
enable_xla
)
_ensure_dir
(
flags_obj
.
model_dir
)
_ensure_dir
(
flags_obj
.
model_dir
)
if
self
.
distribution_strategy
:
with
distribution_utils
.
get_strategy_scope
(
self
.
distribution_strategy
):
with
self
.
distribution_strategy
.
scope
():
model
=
transformer
.
create_model
(
params
,
is_train
=
True
)
model
=
transformer
.
create_model
(
params
,
is_train
)
opt
=
self
.
_create_optimizer
()
model
.
compile
(
opt
)
else
:
model
=
transformer
.
create_model
(
params
,
is_train
)
opt
=
self
.
_create_optimizer
()
opt
=
self
.
_create_optimizer
()
model
.
compile
(
opt
)
if
params
[
"use_ctl"
]:
train_loss_metric
=
tf
.
keras
.
metrics
.
Mean
(
"training_loss"
,
dtype
=
tf
.
float32
)
else
:
model
.
compile
(
opt
)
model
.
summary
()
model
.
summary
()
train_ds
=
data_pipeline
.
train_input_fn
(
params
)
if
self
.
use_tpu
:
map_data_fn
=
data_pipeline
.
map_data_for_transformer_fn
# Different from experimental_distribute_dataset,
train_ds
=
train_ds
.
map
(
map_data_fn
,
# experimental_distribute_datasets_from_function requires
num_parallel_calls
=
params
[
"num_parallel_calls"
])
# per-replica/local batch size.
params
[
"batch_size"
]
/=
self
.
distribution_strategy
.
num_replicas_in_sync
train_ds
=
(
self
.
distribution_strategy
.
experimental_distribute_datasets_from_function
(
lambda
ctx
:
data_pipeline
.
train_input_fn
(
params
)))
else
:
train_ds
=
data_pipeline
.
train_input_fn
(
params
)
map_data_fn
=
data_pipeline
.
map_data_for_transformer_fn
train_ds
=
train_ds
.
map
(
map_data_fn
,
num_parallel_calls
=
params
[
"num_parallel_calls"
])
if
params
[
"use_ctl"
]:
train_ds_iterator
=
iter
(
train_ds
)
callbacks
=
self
.
_create_callbacks
(
flags_obj
.
model_dir
,
0
,
params
)
callbacks
=
self
.
_create_callbacks
(
flags_obj
.
model_dir
,
0
,
params
)
# TODO(b/139418525): Refactor the custom training loop logic.
@
tf
.
function
def
train_steps
(
iterator
,
steps
):
"""Training steps function for TPU runs.
Args:
iterator: The input iterator of the training dataset.
steps: An integer, the number of training steps.
Returns:
A float, the loss value.
"""
def
_step_fn
(
inputs
):
"""Per-replica step function."""
inputs
,
targets
=
inputs
with
tf
.
GradientTape
()
as
tape
:
logits
=
model
([
inputs
,
targets
],
training
=
True
)
loss
=
metrics
.
transformer_loss
(
logits
,
targets
,
params
[
"label_smoothing"
],
params
[
"vocab_size"
])
# Scales the loss, which results in using the average loss across all
# of the replicas for backprop.
scaled_loss
=
loss
/
self
.
distribution_strategy
.
num_replicas_in_sync
# De-dupes variables due to keras tracking issues.
tvars
=
list
(
object_identity
.
ObjectIdentitySet
(
model
.
trainable_variables
))
grads
=
tape
.
gradient
(
scaled_loss
,
tvars
)
opt
.
apply_gradients
(
zip
(
grads
,
tvars
))
# For reporting, the metric takes the mean of losses.
train_loss_metric
.
update_state
(
loss
)
for
_
in
tf
.
range
(
steps
):
train_loss_metric
.
reset_states
()
self
.
distribution_strategy
.
experimental_run_v2
(
_step_fn
,
args
=
(
next
(
iterator
),))
if
self
.
use_tpu
:
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
model
,
optimizer
=
opt
)
latest_checkpoint
=
tf
.
train
.
latest_checkpoint
(
flags_obj
.
model_dir
)
if
latest_checkpoint
:
checkpoint
.
restore
(
latest_checkpoint
)
logging
.
info
(
"Loaded checkpoint %s"
,
latest_checkpoint
)
if
flags_obj
.
train_steps
<
flags_obj
.
steps_between_evals
:
if
flags_obj
.
train_steps
<
flags_obj
.
steps_between_evals
:
flags_obj
.
steps_between_evals
=
flags_obj
.
train_steps
flags_obj
.
steps_between_evals
=
flags_obj
.
train_steps
iterations
=
flags_obj
.
train_steps
//
flags_obj
.
steps_between_evals
iterations
=
flags_obj
.
train_steps
//
flags_obj
.
steps_between_evals
...
@@ -165,28 +288,54 @@ class TransformerTask(object):
...
@@ -165,28 +288,54 @@ class TransformerTask(object):
cased_score_history
,
uncased_score_history
=
[],
[]
cased_score_history
,
uncased_score_history
=
[],
[]
for
i
in
range
(
1
,
iterations
+
1
):
for
i
in
range
(
1
,
iterations
+
1
):
print
(
"Start train iteration:{}/{}"
.
format
(
i
,
iterations
))
print
(
"Start train iteration:{}/{}"
.
format
(
i
,
iterations
))
history
=
model
.
fit
(
history
=
None
train_ds
,
if
params
[
"use_ctl"
]:
initial_epoch
=
i
-
1
,
if
not
self
.
use_tpu
:
epochs
=
i
,
raise
NotImplementedError
(
steps_per_epoch
=
flags_obj
.
steps_between_evals
,
"Custom training loop on GPUs is not implemented."
)
callbacks
=
callbacks
,
train_steps_per_eval
=
tf
.
convert_to_tensor
(
# If TimeHistory is enabled, progress bar would be messy. Increase the
flags_obj
.
steps_between_evals
,
dtype
=
tf
.
int32
)
# verbose level to get rid of it.
verbose
=
(
2
if
flags_obj
.
enable_time_history
else
1
))
# Runs training steps.
train_steps
(
train_ds_iterator
,
train_steps_per_eval
)
train_loss
=
train_loss_metric
.
result
().
numpy
().
astype
(
float
)
logging
.
info
(
"Train Step: %d/%d / loss = %s"
,
i
*
flags_obj
.
steps_between_evals
,
flags_obj
.
train_steps
,
train_loss
)
checkpoint_name
=
checkpoint
.
save
(
os
.
path
.
join
(
flags_obj
.
model_dir
,
"ctl_step_{}.ckpt"
.
format
(
i
*
flags_obj
.
steps_between_evals
)))
logging
.
info
(
"Saved checkpoint to %s"
,
checkpoint_name
)
else
:
if
self
.
use_tpu
:
raise
NotImplementedError
(
"Keras model.fit on TPUs is not implemented."
)
history
=
model
.
fit
(
train_ds
,
initial_epoch
=
i
-
1
,
epochs
=
i
,
steps_per_epoch
=
flags_obj
.
steps_between_evals
,
callbacks
=
callbacks
,
# If TimeHistory is enabled, progress bar would be messy. Increase
# the verbose level to get rid of it.
verbose
=
(
2
if
flags_obj
.
enable_time_history
else
1
))
logging
.
info
(
"Train history: {}"
.
format
(
history
.
history
))
print
(
"End train iteration:{}/{} global step:{}"
.
format
(
print
(
"End train iteration:{}/{} global step:{}"
.
format
(
i
,
i
,
iterations
,
iterations
,
i
*
flags_obj
.
steps_between_evals
))
i
*
flags_obj
.
steps_between_evals
))
tf
.
compat
.
v1
.
logging
.
info
(
"Train history: {}"
.
format
(
history
.
history
))
stats
=
misc
.
build_stats
(
history
,
callbacks
)
if
(
flags_obj
.
bleu_source
and
flags_obj
.
bleu_ref
):
if
(
flags_obj
.
bleu_source
and
flags_obj
.
bleu_ref
):
uncased_score
,
cased_score
=
self
.
eval
()
uncased_score
,
cased_score
=
self
.
eval
()
cased_score_history
.
append
([
i
,
cased_score
])
cased_score_history
.
append
([
i
,
cased_score
])
uncased_score_history
.
append
([
i
,
uncased_score
])
uncased_score_history
.
append
([
i
,
uncased_score
])
stats
=
misc
.
build_stats
(
history
,
callbacks
)
stats
=
({
"loss"
:
train_loss
}
if
history
is
None
else
misc
.
build_stats
(
history
,
callbacks
))
if
uncased_score
and
cased_score
:
if
uncased_score
and
cased_score
:
stats
[
"bleu_uncased"
]
=
uncased_score
stats
[
"bleu_uncased"
]
=
uncased_score
stats
[
"bleu_cased"
]
=
cased_score
stats
[
"bleu_cased"
]
=
cased_score
...
@@ -202,17 +351,18 @@ class TransformerTask(object):
...
@@ -202,17 +351,18 @@ class TransformerTask(object):
self
.
predict_model
,
self
.
predict_model
,
tf
.
train
.
latest_checkpoint
(
self
.
flags_obj
.
model_dir
))
tf
.
train
.
latest_checkpoint
(
self
.
flags_obj
.
model_dir
))
self
.
predict_model
.
summary
()
self
.
predict_model
.
summary
()
return
evaluate_and_log_bleu
(
self
.
predict_model
,
return
evaluate_and_log_bleu
(
self
.
flags_obj
.
bleu_source
,
self
.
predict_model
,
self
.
params
,
self
.
flags_obj
.
bleu_source
,
self
.
flags_obj
.
bleu_ref
,
self
.
flags_obj
.
bleu_ref
,
self
.
flags_obj
.
vocab_file
,
self
.
flags_obj
.
vocab_fil
e
)
self
.
distribution_strategy
if
self
.
use_tpu
else
Non
e
)
def
predict
(
self
):
def
predict
(
self
):
"""Predicts result from the model."""
"""Predicts result from the model."""
params
,
flags_obj
,
is_train
=
self
.
params
,
self
.
flags_obj
,
False
params
=
self
.
params
flags_obj
=
self
.
flags_obj
with
tf
.
name_scope
(
"model"
):
with
tf
.
name_scope
(
"model"
):
model
=
transformer
.
create_model
(
params
,
is_train
)
model
=
transformer
.
create_model
(
params
,
is_train
=
False
)
self
.
_load_weights_if_possible
(
self
.
_load_weights_if_possible
(
model
,
tf
.
train
.
latest_checkpoint
(
self
.
flags_obj
.
model_dir
))
model
,
tf
.
train
.
latest_checkpoint
(
self
.
flags_obj
.
model_dir
))
model
.
summary
()
model
.
summary
()
...
@@ -242,16 +392,28 @@ class TransformerTask(object):
...
@@ -242,16 +392,28 @@ class TransformerTask(object):
def
_load_weights_if_possible
(
self
,
model
,
init_weight_path
=
None
):
def
_load_weights_if_possible
(
self
,
model
,
init_weight_path
=
None
):
"""Loads model weights when it is provided."""
"""Loads model weights when it is provided."""
if
init_weight_path
:
if
init_weight_path
:
tf
.
compat
.
v1
.
logging
.
info
(
"Load weights: {}"
.
format
(
init_weight_path
))
logging
.
info
(
"Load weights: {}"
.
format
(
init_weight_path
))
model
.
load_weights
(
init_weight_path
)
# TODO(b/139414977): Having the same variable restoring method for both
# TPU and GPU.
if
self
.
use_tpu
:
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
model
,
optimizer
=
self
.
_create_optimizer
())
checkpoint
.
restore
(
init_weight_path
)
else
:
model
.
load_weights
(
init_weight_path
)
else
:
else
:
print
(
"Weights not loaded from path:{}"
.
format
(
init_weight_path
))
print
(
"Weights not loaded from path:{}"
.
format
(
init_weight_path
))
def
_create_optimizer
(
self
):
def
_create_optimizer
(
self
):
"""Creates optimizer."""
"""Creates optimizer."""
params
=
self
.
params
params
=
self
.
params
# TODO(b/139414679): Explore the difference between using
# LearningRateSchedule and callback for GPU runs, and try to merge them.
lr_schedule
=
optimizer
.
LearningRateSchedule
(
params
[
"learning_rate"
],
params
[
"hidden_size"
],
params
[
"learning_rate_warmup_steps"
])
opt
=
tf
.
keras
.
optimizers
.
Adam
(
opt
=
tf
.
keras
.
optimizers
.
Adam
(
params
[
"learning_rate"
],
lr_schedule
if
self
.
use_tpu
else
params
[
"learning_rate"
],
params
[
"optimizer_adam_beta1"
],
params
[
"optimizer_adam_beta1"
],
params
[
"optimizer_adam_beta2"
],
params
[
"optimizer_adam_beta2"
],
epsilon
=
params
[
"optimizer_adam_epsilon"
])
epsilon
=
params
[
"optimizer_adam_epsilon"
])
...
@@ -264,25 +426,34 @@ class TransformerTask(object):
...
@@ -264,25 +426,34 @@ class TransformerTask(object):
def
_ensure_dir
(
log_dir
):
def
_ensure_dir
(
log_dir
):
"""Makes log dir if not existed."""
"""Makes log dir if not existed."""
if
not
os
.
path
.
exists
(
log_dir
):
if
not
tf
.
io
.
gfile
.
exists
(
log_dir
):
os
.
makedirs
(
log_dir
)
tf
.
io
.
gfile
.
makedirs
(
log_dir
)
def
main
(
_
):
def
main
(
_
):
flags_obj
=
flags
.
FLAGS
flags_obj
=
flags
.
FLAGS
with
logger
.
benchmark_context
(
flags_obj
):
with
logger
.
benchmark_context
(
flags_obj
):
task
=
TransformerTask
(
flags_obj
)
task
=
TransformerTask
(
flags_obj
)
if
flags_obj
.
mode
==
"train"
:
task
.
train
()
def
_run_task
(
task
):
elif
flags_obj
.
mode
==
"predict"
:
if
flags_obj
.
mode
==
"train"
:
task
.
predict
()
task
.
train
()
elif
flags_obj
.
mode
==
"eval"
:
elif
flags_obj
.
mode
==
"predict"
:
task
.
eval
()
task
.
predict
()
elif
flags_obj
.
mode
==
"eval"
:
task
.
eval
()
else
:
raise
ValueError
(
"Invalid mode {}"
.
format
(
flags_obj
.
mode
))
if
not
flags_obj
.
distribution_strategy
!=
"tpu"
:
_run_task
(
task
)
else
:
else
:
raise
ValueError
(
"Invalid mode {}"
.
format
(
flags_obj
.
mode
))
primary_cpu_task
=
"/job:worker"
if
flags_obj
.
use_tpu_2vm_config
else
""
with
tf
.
device
(
primary_cpu_task
):
_run_task
(
task
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
misc
.
define_transformer_flags
()
misc
.
define_transformer_flags
()
absl_app
.
run
(
main
)
absl_app
.
run
(
main
)
official/transformer/v2/transformer_main_test.py
View file @
bf748370
...
@@ -30,7 +30,7 @@ from official.transformer.v2 import misc
...
@@ -30,7 +30,7 @@ from official.transformer.v2 import misc
from
official.transformer.v2
import
transformer_main
as
tm
from
official.transformer.v2
import
transformer_main
as
tm
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
tensorflow.python.eager
import
context
# pylint: disable=ungrouped-imports
from
tensorflow.python.eager
import
context
# pylint: disable=ungrouped-imports
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
FIXED_TIMESTAMP
=
'my_time_stamp'
FIXED_TIMESTAMP
=
'my_time_stamp'
...
@@ -80,11 +80,19 @@ class TransformerTaskTest(tf.test.TestCase):
...
@@ -80,11 +80,19 @@ class TransformerTaskTest(tf.test.TestCase):
self
.
assertTrue
(
os
.
path
.
exists
(
filepath
))
self
.
assertTrue
(
os
.
path
.
exists
(
filepath
))
def
test_train_no_dist_strat
(
self
):
def
test_train_no_dist_strat
(
self
):
if
context
.
num_gpus
()
>=
2
:
self
.
skipTest
(
'No need to test 2+ GPUs without a distribution strategy.'
)
t
=
tm
.
TransformerTask
(
FLAGS
)
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
train
()
t
.
train
()
def
test_train_static_batch
(
self
):
def
test_train_static_batch
(
self
):
if
context
.
num_gpus
()
>=
2
:
self
.
skipTest
(
'No need to test 2+ GPUs without a distribution strategy.'
)
FLAGS
.
distribution_strategy
=
'one_device'
FLAGS
.
distribution_strategy
=
'one_device'
if
tf
.
test
.
is_built_with_cuda
():
FLAGS
.
num_gpus
=
1
else
:
FLAGS
.
num_gpus
=
0
FLAGS
.
static_batch
=
True
FLAGS
.
static_batch
=
True
t
=
tm
.
TransformerTask
(
FLAGS
)
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
train
()
t
.
train
()
...
@@ -97,6 +105,7 @@ class TransformerTaskTest(tf.test.TestCase):
...
@@ -97,6 +105,7 @@ class TransformerTaskTest(tf.test.TestCase):
@
unittest
.
skipUnless
(
tf
.
test
.
is_built_with_cuda
(),
'requires GPU'
)
@
unittest
.
skipUnless
(
tf
.
test
.
is_built_with_cuda
(),
'requires GPU'
)
def
test_train_fp16
(
self
):
def
test_train_fp16
(
self
):
FLAGS
.
distribution_strategy
=
'one_device'
FLAGS
.
dtype
=
'fp16'
FLAGS
.
dtype
=
'fp16'
t
=
tm
.
TransformerTask
(
FLAGS
)
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
train
()
t
.
train
()
...
@@ -105,8 +114,8 @@ class TransformerTaskTest(tf.test.TestCase):
...
@@ -105,8 +114,8 @@ class TransformerTaskTest(tf.test.TestCase):
def
test_train_2_gpu
(
self
):
def
test_train_2_gpu
(
self
):
if
context
.
num_gpus
()
<
2
:
if
context
.
num_gpus
()
<
2
:
self
.
skipTest
(
self
.
skipTest
(
'{} GPUs are not available for this test. {} GPUs are available'
.
'{} GPUs are not available for this test. {} GPUs are available'
format
(
2
,
context
.
num_gpus
()))
.
format
(
2
,
context
.
num_gpus
()))
FLAGS
.
distribution_strategy
=
'mirrored'
FLAGS
.
distribution_strategy
=
'mirrored'
FLAGS
.
num_gpus
=
2
FLAGS
.
num_gpus
=
2
FLAGS
.
param_set
=
'base'
FLAGS
.
param_set
=
'base'
...
@@ -117,8 +126,8 @@ class TransformerTaskTest(tf.test.TestCase):
...
@@ -117,8 +126,8 @@ class TransformerTaskTest(tf.test.TestCase):
def
test_train_2_gpu_fp16
(
self
):
def
test_train_2_gpu_fp16
(
self
):
if
context
.
num_gpus
()
<
2
:
if
context
.
num_gpus
()
<
2
:
self
.
skipTest
(
self
.
skipTest
(
'{} GPUs are not available for this test. {} GPUs are available'
.
'{} GPUs are not available for this test. {} GPUs are available'
format
(
2
,
context
.
num_gpus
()))
.
format
(
2
,
context
.
num_gpus
()))
FLAGS
.
distribution_strategy
=
'mirrored'
FLAGS
.
distribution_strategy
=
'mirrored'
FLAGS
.
num_gpus
=
2
FLAGS
.
num_gpus
=
2
FLAGS
.
param_set
=
'base'
FLAGS
.
param_set
=
'base'
...
@@ -153,16 +162,22 @@ class TransformerTaskTest(tf.test.TestCase):
...
@@ -153,16 +162,22 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS
(
update_flags
)
FLAGS
(
update_flags
)
def
test_predict
(
self
):
def
test_predict
(
self
):
if
context
.
num_gpus
()
>=
2
:
self
.
skipTest
(
'No need to test 2+ GPUs without a distribution strategy.'
)
self
.
_prepare_files_and_flags
()
self
.
_prepare_files_and_flags
()
t
=
tm
.
TransformerTask
(
FLAGS
)
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
predict
()
t
.
predict
()
def
test_predict_fp16
(
self
):
def
test_predict_fp16
(
self
):
if
context
.
num_gpus
()
>=
2
:
self
.
skipTest
(
'No need to test 2+ GPUs without a distribution strategy.'
)
self
.
_prepare_files_and_flags
(
'--dtype=fp16'
)
self
.
_prepare_files_and_flags
(
'--dtype=fp16'
)
t
=
tm
.
TransformerTask
(
FLAGS
)
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
predict
()
t
.
predict
()
def
test_eval
(
self
):
def
test_eval
(
self
):
if
context
.
num_gpus
()
>=
2
:
self
.
skipTest
(
'No need to test 2+ GPUs without a distribution strategy.'
)
self
.
_prepare_files_and_flags
()
self
.
_prepare_files_and_flags
()
t
=
tm
.
TransformerTask
(
FLAGS
)
t
=
tm
.
TransformerTask
(
FLAGS
)
t
.
eval
()
t
.
eval
()
...
...
official/transformer/v2/translate.py
View file @
bf748370
...
@@ -18,11 +18,12 @@ from __future__ import absolute_import
...
@@ -18,11 +18,12 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
values
from
official.transformer.utils
import
tokenizer
from
official.transformer.utils
import
tokenizer
_DECODE_BATCH_SIZE
=
32
_EXTRA_DECODE_LENGTH
=
100
_EXTRA_DECODE_LENGTH
=
100
_BEAM_SIZE
=
4
_BEAM_SIZE
=
4
_ALPHA
=
0.6
_ALPHA
=
0.6
...
@@ -68,23 +69,31 @@ def _trim_and_decode(ids, subtokenizer):
...
@@ -68,23 +69,31 @@ def _trim_and_decode(ids, subtokenizer):
return
subtokenizer
.
decode
(
ids
)
return
subtokenizer
.
decode
(
ids
)
def
translate_file
(
def
translate_file
(
model
,
model
,
subtokenizer
,
input_file
,
output_file
=
None
,
params
,
print_all_translations
=
True
):
subtokenizer
,
input_file
,
output_file
=
None
,
print_all_translations
=
True
,
distribution_strategy
=
None
):
"""Translate lines in file, and save to output file if specified.
"""Translate lines in file, and save to output file if specified.
Args:
Args:
model: Keras model used to generate the translations.
model: A Keras model, used to generate the translations.
subtokenizer: Subtokenizer object for encoding and decoding source and
params: A dictionary, containing the translation related parameters.
translated lines.
subtokenizer: A subtokenizer object, used for encoding and decoding source
input_file: file containing lines to translate
and translated lines.
output_file: file that stores the generated translations.
input_file: A file containing lines to translate.
print_all_translations: If true, all translations are printed to stdout.
output_file: A file that stores the generated translations.
print_all_translations: A bool. If true, all translations are printed to
stdout.
distribution_strategy: A distribution strategy, used to perform inference
directly with tf.function instead of Keras model.predict().
Raises:
Raises:
ValueError: if output file is invalid.
ValueError: if output file is invalid.
"""
"""
batch_size
=
_DECODE_BATCH_SIZE
batch_size
=
params
[
"decode_batch_size"
]
# Read and sort inputs by length. Keep dictionary (original index-->new index
# Read and sort inputs by length. Keep dictionary (original index-->new index
# in sorted list) to write translations in the original order.
# in sorted list) to write translations in the original order.
...
@@ -101,24 +110,59 @@ def translate_file(
...
@@ -101,24 +110,59 @@ def translate_file(
if
j
+
i
*
batch_size
<
total_samples
if
j
+
i
*
batch_size
<
total_samples
]
]
lines
=
[
_encode_and_add_eos
(
l
,
subtokenizer
)
for
l
in
lines
]
lines
=
[
_encode_and_add_eos
(
l
,
subtokenizer
)
for
l
in
lines
]
if
distribution_strategy
:
for
j
in
range
(
batch_size
-
len
(
lines
)):
lines
.
append
([
tokenizer
.
EOS_ID
])
batch
=
tf
.
keras
.
preprocessing
.
sequence
.
pad_sequences
(
batch
=
tf
.
keras
.
preprocessing
.
sequence
.
pad_sequences
(
lines
,
dtype
=
"int64"
,
padding
=
"post"
)
lines
,
maxlen
=
params
[
"decode_max_length"
],
dtype
=
"int32"
,
padding
=
"post"
)
tf
.
compat
.
v1
.
logging
.
info
(
"Decoding batch %d out of %d."
,
i
,
tf
.
compat
.
v1
.
logging
.
info
(
"Decoding batch %d out of %d."
,
i
,
num_decode_batches
)
num_decode_batches
)
yield
batch
yield
batch
@
tf
.
function
def
predict_step
(
inputs
):
"""Decoding step function for TPU runs."""
def
_step_fn
(
inputs
):
"""Per replica step function."""
val_outputs
,
_
=
model
([
inputs
],
training
=
False
)
return
val_outputs
return
distribution_strategy
.
experimental_run_v2
(
_step_fn
,
args
=
(
inputs
,))
translations
=
[]
translations
=
[]
if
distribution_strategy
:
num_replicas
=
distribution_strategy
.
num_replicas_in_sync
local_batch_size
=
params
[
"decode_batch_size"
]
//
num_replicas
for
i
,
text
in
enumerate
(
input_generator
()):
for
i
,
text
in
enumerate
(
input_generator
()):
val_outputs
,
_
=
model
.
predict
(
text
)
if
distribution_strategy
:
text
=
np
.
reshape
(
text
,
[
num_replicas
,
local_batch_size
,
-
1
])
text
=
[
tf
.
convert_to_tensor
(
per_replica_text
)
for
per_replica_text
in
text
]
# pylint: disable=protected-access
text
=
values
.
PerReplica
(
distribution_strategy
.
extended
.
_device_map
,
text
)
# pylint: enable=protected-access
val_outputs
=
distribution_strategy
.
experimental_local_results
(
predict_step
(
text
))
val_outputs
=
np
.
reshape
(
[
val_output
.
numpy
()
for
val_output
in
val_outputs
],
[
params
[
"decode_batch_size"
],
-
1
])
else
:
val_outputs
,
_
=
model
.
predict
(
text
)
length
=
len
(
val_outputs
)
length
=
len
(
val_outputs
)
for
j
in
range
(
length
):
for
j
in
range
(
length
):
translation
=
_trim_and_decode
(
val_outputs
[
j
],
subtokenizer
)
if
j
+
i
*
batch_size
<
total_samples
:
translations
.
append
(
translation
)
translation
=
_trim_and_decode
(
val_outputs
[
j
],
subtokenizer
)
if
print_all_translations
:
translations
.
append
(
translation
)
tf
.
compat
.
v1
.
logging
.
info
(
if
print_all_translations
:
"Translating:
\n\t
Input: %s
\n\t
Output: %s"
%
tf
.
compat
.
v1
.
logging
.
info
(
(
sorted_inputs
[
j
+
i
*
batch_size
],
translation
))
"Translating:
\n\t
Input: %s
\n\t
Output: %s"
%
(
sorted_inputs
[
j
+
i
*
batch_size
],
translation
))
# Write translations in the order they appeared in the original file.
# Write translations in the order they appeared in the original file.
if
output_file
is
not
None
:
if
output_file
is
not
None
:
...
...
official/utils/flags/_distribution.py
0 → 100644
View file @
bf748370
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags related to distributed execution."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl
import
flags
import
tensorflow
as
tf
from
official.utils.flags._conventions
import
help_wrap
def
define_distribution
(
worker_hosts
=
True
,
task_index
=
True
):
"""Register distributed execution flags.
Args:
worker_hosts: Create a flag for specifying comma-separated list of workers.
task_index: Create a flag for specifying index of task.
Returns:
A list of flags for core.py to marks as key flags.
"""
key_flags
=
[]
if
worker_hosts
:
flags
.
DEFINE_string
(
name
=
'worker_hosts'
,
default
=
None
,
help
=
help_wrap
(
'Comma-separated list of worker ip:port pairs for running '
'multi-worker models with DistributionStrategy. The user would '
'start the program on each host with identical value for this '
'flag.'
))
if
task_index
:
flags
.
DEFINE_integer
(
name
=
'task_index'
,
default
=-
1
,
help
=
help_wrap
(
'If multi-worker training, the task_index of this '
'worker.'
))
return
key_flags
official/utils/flags/_performance.py
View file @
bf748370
...
@@ -53,8 +53,8 @@ def get_loss_scale(flags_obj, default_for_fp16):
...
@@ -53,8 +53,8 @@ def get_loss_scale(flags_obj, default_for_fp16):
return
default_for_fp16
return
default_for_fp16
def
define_performance
(
num_parallel_calls
=
Tru
e
,
inter_op
=
Tru
e
,
intra_op
=
Tru
e
,
def
define_performance
(
num_parallel_calls
=
Fals
e
,
inter_op
=
Fals
e
,
intra_op
=
Fals
e
,
synthetic_data
=
True
,
max_train_steps
=
Tru
e
,
dtype
=
True
,
synthetic_data
=
True
,
max_train_steps
=
Fals
e
,
dtype
=
True
,
all_reduce_alg
=
True
,
num_packs
=
True
,
all_reduce_alg
=
True
,
num_packs
=
True
,
tf_gpu_thread_mode
=
False
,
tf_gpu_thread_mode
=
False
,
datasets_num_private_threads
=
False
,
datasets_num_private_threads
=
False
,
...
...
official/utils/flags/core.py
View file @
bf748370
...
@@ -32,6 +32,7 @@ from official.utils.flags import _base
...
@@ -32,6 +32,7 @@ from official.utils.flags import _base
from
official.utils.flags
import
_benchmark
from
official.utils.flags
import
_benchmark
from
official.utils.flags
import
_conventions
from
official.utils.flags
import
_conventions
from
official.utils.flags
import
_device
from
official.utils.flags
import
_device
from
official.utils.flags
import
_distribution
from
official.utils.flags
import
_misc
from
official.utils.flags
import
_misc
from
official.utils.flags
import
_performance
from
official.utils.flags
import
_performance
...
@@ -77,6 +78,8 @@ define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
...
@@ -77,6 +78,8 @@ define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device
=
register_key_flags_in_core
(
_device
.
define_device
)
define_device
=
register_key_flags_in_core
(
_device
.
define_device
)
define_image
=
register_key_flags_in_core
(
_misc
.
define_image
)
define_image
=
register_key_flags_in_core
(
_misc
.
define_image
)
define_performance
=
register_key_flags_in_core
(
_performance
.
define_performance
)
define_performance
=
register_key_flags_in_core
(
_performance
.
define_performance
)
define_distribution
=
register_key_flags_in_core
(
_distribution
.
define_distribution
)
help_wrap
=
_conventions
.
help_wrap
help_wrap
=
_conventions
.
help_wrap
...
...
official/utils/flags/flags_test.py
View file @
bf748370
...
@@ -23,7 +23,9 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
...
@@ -23,7 +23,9 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def
define_flags
():
def
define_flags
():
flags_core
.
define_base
(
num_gpu
=
False
)
flags_core
.
define_base
(
num_gpu
=
False
)
flags_core
.
define_performance
(
dynamic_loss_scale
=
True
,
loss_scale
=
True
)
flags_core
.
define_performance
(
num_parallel_calls
=
True
,
inter_op
=
True
,
intra_op
=
True
,
dynamic_loss_scale
=
True
,
loss_scale
=
True
)
flags_core
.
define_image
()
flags_core
.
define_image
()
flags_core
.
define_benchmark
()
flags_core
.
define_benchmark
()
...
...
official/utils/misc/distribution_utils.py
View file @
bf748370
...
@@ -127,10 +127,7 @@ def get_distribution_strategy(distribution_strategy="default",
...
@@ -127,10 +127,7 @@ def get_distribution_strategy(distribution_strategy="default",
return
None
return
None
if
distribution_strategy
==
"tpu"
:
if
distribution_strategy
==
"tpu"
:
if
not
tpu_address
:
# When tpu_address is an empty string, we communicate with local TPUs.
raise
ValueError
(
"`tpu_address` must be specified when using "
"TPUStrategy."
)
# Initialize TPU System.
# Initialize TPU System.
cluster_resolver
=
tpu_lib
.
tpu_initialize
(
tpu_address
)
cluster_resolver
=
tpu_lib
.
tpu_initialize
(
tpu_address
)
return
tf
.
distribute
.
experimental
.
TPUStrategy
(
cluster_resolver
)
return
tf
.
distribute
.
experimental
.
TPUStrategy
(
cluster_resolver
)
...
@@ -205,38 +202,64 @@ class SyntheticDataset(object):
...
@@ -205,38 +202,64 @@ class SyntheticDataset(object):
"""A dataset that generates synthetic data on each device."""
"""A dataset that generates synthetic data on each device."""
def
__init__
(
self
,
dataset
,
split_by
=
1
):
def
__init__
(
self
,
dataset
,
split_by
=
1
):
self
.
_input_data
=
{}
# dataset.take(1) doesn't have GPU kernel.
# dataset.take(1) doesn't have GPU kernel.
with
tf
.
device
(
'device:CPU:0'
):
with
tf
.
device
(
'device:CPU:0'
):
tensor
=
tf
.
data
.
experimental
.
get_single_element
(
dataset
.
take
(
1
))
tensor
=
tf
.
data
.
experimental
.
get_single_element
(
dataset
.
take
(
1
))
flat_tensor
=
tf
.
nest
.
flatten
(
tensor
)
flat_tensor
=
tf
.
nest
.
flatten
(
tensor
)
variable_data
=
[]
variable_data
=
[]
self
.
_
initializers
=
[]
initializers
=
[]
for
t
in
flat_tensor
:
for
t
in
flat_tensor
:
rebatched_t
=
tf
.
split
(
t
,
num_or_size_splits
=
split_by
,
axis
=
0
)[
0
]
rebatched_t
=
tf
.
split
(
t
,
num_or_size_splits
=
split_by
,
axis
=
0
)[
0
]
assert
rebatched_t
.
shape
.
is_fully_defined
(),
rebatched_t
.
shape
assert
rebatched_t
.
shape
.
is_fully_defined
(),
rebatched_t
.
shape
v
=
tf
.
compat
.
v1
.
get_local_variable
(
self
.
random_name
(),
v
=
tf
.
compat
.
v1
.
get_local_variable
(
self
.
_
random_name
(),
initializer
=
rebatched_t
)
initializer
=
rebatched_t
)
variable_data
.
append
(
v
)
variable_data
.
append
(
v
)
self
.
_initializers
.
append
(
v
.
initializer
)
initializers
.
append
(
v
.
initializer
)
self
.
_input_data
=
tf
.
nest
.
pack_sequence_as
(
tensor
,
variable_data
)
input_data
=
tf
.
nest
.
pack_sequence_as
(
tensor
,
variable_data
)
self
.
_iterator
=
SyntheticIterator
(
input_data
,
initializers
)
def
_random_name
(
self
,
size
=
10
,
chars
=
string
.
ascii_uppercase
+
string
.
digits
):
return
''
.
join
(
random
.
choice
(
chars
)
for
_
in
range
(
size
))
def
__iter__
(
self
):
return
self
.
_iterator
def
make_one_shot_iterator
(
self
):
return
self
.
_iterator
def
make_initializable_iterator
(
self
):
return
self
.
_iterator
class
SyntheticIterator
(
object
):
"""A dataset that generates synthetic data on each device."""
def
__init__
(
self
,
input_data
,
initializers
):
self
.
_input_data
=
input_data
self
.
_initializers
=
initializers
def
get_next
(
self
):
def
get_next
(
self
):
return
self
.
_input_data
return
self
.
_input_data
def
next
(
self
):
return
self
.
__next__
()
def
__next__
(
self
):
try
:
return
self
.
get_next
()
except
tf
.
errors
.
OutOfRangeError
:
raise
StopIteration
def
initialize
(
self
):
def
initialize
(
self
):
if
tf
.
executing_eagerly
():
if
tf
.
executing_eagerly
():
return
tf
.
no_op
()
return
tf
.
no_op
()
else
:
else
:
return
self
.
_initializers
return
self
.
_initializers
def
random_name
(
self
,
size
=
10
,
chars
=
string
.
ascii_uppercase
+
string
.
digits
):
return
''
.
join
(
random
.
choice
(
chars
)
for
_
in
range
(
size
))
def
_monkey_patch_dataset_method
(
strategy
):
def
_monkey_patch_dataset_method
(
strategy
):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def
make_dataset
_iterator
(
self
,
dataset
):
def
make_dataset
(
self
,
dataset
):
tf
.
compat
.
v1
.
logging
.
info
(
'Using pure synthetic data.'
)
tf
.
compat
.
v1
.
logging
.
info
(
'Using pure synthetic data.'
)
with
self
.
scope
():
with
self
.
scope
():
if
self
.
extended
.
_global_batch_size
:
# pylint: disable=protected-access
if
self
.
extended
.
_global_batch_size
:
# pylint: disable=protected-access
...
@@ -244,22 +267,34 @@ def _monkey_patch_dataset_method(strategy):
...
@@ -244,22 +267,34 @@ def _monkey_patch_dataset_method(strategy):
else
:
else
:
return
SyntheticDataset
(
dataset
)
return
SyntheticDataset
(
dataset
)
strategy
.
org_make_dataset_iterator
=
strategy
.
make_dataset_iterator
def
make_iterator
(
self
,
dataset
):
strategy
.
make_dataset_iterator
=
make_dataset_iterator
dist_dataset
=
make_dataset
(
self
,
dataset
)
return
iter
(
dist_dataset
)
strategy
.
orig_make_dataset_iterator
=
strategy
.
make_dataset_iterator
strategy
.
make_dataset_iterator
=
make_iterator
strategy
.
orig_distribute_dataset
=
strategy
.
experimental_distribute_dataset
strategy
.
experimental_distribute_dataset
=
make_dataset
def
_undo_monkey_patch_dataset_method
(
strategy
):
def
_undo_monkey_patch_dataset_method
(
strategy
):
if
hasattr
(
strategy
,
'org_make_dataset_iterator'
):
if
hasattr
(
strategy
,
'orig_make_dataset_iterator'
):
strategy
.
make_dataset_iterator
=
strategy
.
org_make_dataset_iterator
strategy
.
make_dataset_iterator
=
strategy
.
orig_make_dataset_iterator
if
hasattr
(
strategy
,
'orig_distribute_dataset'
):
strategy
.
make_dataset_iterator
=
strategy
.
orig_distribute_dataset
def
set_up_synthetic_data
():
def
set_up_synthetic_data
():
_monkey_patch_dataset_method
(
tf
.
distribute
.
OneDeviceStrategy
)
_monkey_patch_dataset_method
(
tf
.
distribute
.
OneDeviceStrategy
)
_monkey_patch_dataset_method
(
tf
.
distribute
.
MirroredStrategy
)
_monkey_patch_dataset_method
(
tf
.
distribute
.
MirroredStrategy
)
_monkey_patch_dataset_method
(
tf
.
distribute
.
experimental
.
MultiWorkerMirroredStrategy
)
# TODO(tobyboyd): Remove when contrib.distribute is all in core.
# TODO(tobyboyd): Remove when contrib.distribute is all in core.
if
hasattr
(
tf
,
'contrib'
):
if
hasattr
(
tf
,
'contrib'
):
_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
MirroredStrategy
)
_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
MirroredStrategy
)
_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
OneDeviceStrategy
)
_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
OneDeviceStrategy
)
_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
CollectiveAllReduceStrategy
)
else
:
else
:
print
(
'Contrib missing: Skip monkey patch tf.contrib.distribute.*'
)
print
(
'Contrib missing: Skip monkey patch tf.contrib.distribute.*'
)
...
@@ -267,10 +302,14 @@ def set_up_synthetic_data():
...
@@ -267,10 +302,14 @@ def set_up_synthetic_data():
def
undo_set_up_synthetic_data
():
def
undo_set_up_synthetic_data
():
_undo_monkey_patch_dataset_method
(
tf
.
distribute
.
OneDeviceStrategy
)
_undo_monkey_patch_dataset_method
(
tf
.
distribute
.
OneDeviceStrategy
)
_undo_monkey_patch_dataset_method
(
tf
.
distribute
.
MirroredStrategy
)
_undo_monkey_patch_dataset_method
(
tf
.
distribute
.
MirroredStrategy
)
_undo_monkey_patch_dataset_method
(
tf
.
distribute
.
experimental
.
MultiWorkerMirroredStrategy
)
# TODO(tobyboyd): Remove when contrib.distribute is all in core.
# TODO(tobyboyd): Remove when contrib.distribute is all in core.
if
hasattr
(
tf
,
'contrib'
):
if
hasattr
(
tf
,
'contrib'
):
_undo_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
MirroredStrategy
)
_undo_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
MirroredStrategy
)
_undo_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
OneDeviceStrategy
)
_undo_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
OneDeviceStrategy
)
_undo_monkey_patch_dataset_method
(
tf
.
contrib
.
distribute
.
CollectiveAllReduceStrategy
)
else
:
else
:
print
(
'Contrib missing: Skip remove monkey patch tf.contrib.distribute.*'
)
print
(
'Contrib missing: Skip remove monkey patch tf.contrib.distribute.*'
)
...
...
official/utils/testing/integration.py
View file @
bf748370
...
@@ -29,7 +29,7 @@ from absl import flags
...
@@ -29,7 +29,7 @@ from absl import flags
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
def
run_synthetic
(
main
,
tmp_root
,
extra_flags
=
None
,
synth
=
True
,
max_train
=
1
):
def
run_synthetic
(
main
,
tmp_root
,
extra_flags
=
None
,
synth
=
True
):
"""Performs a minimal run of a model.
"""Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A
This function is intended to test for syntax errors throughout a model. A
...
@@ -41,7 +41,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
...
@@ -41,7 +41,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
tmp_root: Root path for the temp directory created by the test class.
tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the caller of this function.
extra_flags: Additional flags passed by the caller of this function.
synth: Use synthetic data.
synth: Use synthetic data.
max_train: Maximum number of allowed training steps.
"""
"""
extra_flags
=
[]
if
extra_flags
is
None
else
extra_flags
extra_flags
=
[]
if
extra_flags
is
None
else
extra_flags
...
@@ -54,9 +53,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
...
@@ -54,9 +53,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
if
synth
:
if
synth
:
args
.
append
(
"--use_synthetic_data"
)
args
.
append
(
"--use_synthetic_data"
)
if
max_train
is
not
None
:
args
.
extend
([
"--max_train_steps"
,
str
(
max_train
)])
try
:
try
:
flags_core
.
parse_flags
(
argv
=
args
)
flags_core
.
parse_flags
(
argv
=
args
)
main
(
flags
.
FLAGS
)
main
(
flags
.
FLAGS
)
...
...
official/utils/testing/pylint.rcfile
View file @
bf748370
[MESSAGES CONTROL]
[MESSAGES CONTROL]
disable=R,W,bad-option-value,trailing-newlines
disable=R,W,bad-option-value,trailing-newlines
,no-name-in-module
[REPORTS]
[REPORTS]
# Tells whether to display a full report or only the messages
# Tells whether to display a full report or only the messages
...
...
official/
resnet/keras
/README.md
→
official/
vision/image_classification
/README.md
View file @
bf748370
This folder contains the Keras implementation of the ResNet models. For more
This folder contains the Keras implementation of the ResNet models. For more
information about the models, please refer to this
[
README file
](
../README.md
)
.
information about the models, please refer to this
[
README file
](
../
../
README.md
)
.
Similar to the
[
estimator implementation
](
/official
/resnet
)
, the Keras
Similar to the
[
estimator implementation
](
../../r1
/resnet
)
, the Keras
implementation has code for both CIFAR-10 data and ImageNet data. The CIFAR-10
implementation has code for both CIFAR-10 data and ImageNet data. The CIFAR-10
version uses a ResNet56 model implemented in
version uses a ResNet56 model implemented in
[
`resnet_cifar_model.py`
](
./resnet_cifar_model.py
)
, and the ImageNet version
[
`resnet_cifar_model.py`
](
./resnet_cifar_model.py
)
, and the ImageNet version
uses a ResNet50 model implemented in
[
`resnet_model.py`
](
./resnet_model.py
)
.
uses a ResNet50 model implemented in
[
`resnet_model.py`
](
./resnet_model.py
)
.
To use
To use
either dataset, make sure that you have the latest version of TensorFlow
either dataset, make sure that you have the latest version of TensorFlow
installed and
installed and
[
add the models folder to your Python path
](
/official/#running-the-models
)
,
[
add the models folder to your Python path
](
/official/#running-the-models
)
,
otherwise you may encounter an error like
`ImportError: No module named
otherwise you may encounter an error like
`ImportError: No module named
official.resnet`
.
official.resnet`
.
## CIFAR-10
## CIFAR-10
Download and extract the CIFAR-10 data. You can use the following script:
Download and extract the CIFAR-10 data. You can use the following script:
```
bash
```
bash
python cifar10_download_and_extract.py
python
../../r1/resnet/
cifar10_download_and_extract.py
```
```
After you download the data, you can run the program by:
After you download the data, you can run the program by:
```
bash
```
bash
python
keras
_cifar_main.py
python
resnet
_cifar_main.py
```
```
If you did not use the default directory to download the data, specify the
If you did not use the default directory to download the data, specify the
location with the
`--data_dir`
flag, like:
location with the
`--data_dir`
flag, like:
```
bash
```
bash
python
keras
_cifar_main.py
--data_dir
=
/path/to/cifar
python
resnet
_cifar_main.py
--data_dir
=
/path/to/cifar
```
```
## ImageNet
## ImageNet
Download the ImageNet dataset and convert it to TFRecord format.
Download the ImageNet dataset and convert it to TFRecord format.
The following
[
script
](
https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py
)
The following
[
script
](
https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py
)
and
[
README
](
https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy
)
and
[
README
](
https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy
)
provide a few options.
provide a few options.
...
@@ -44,57 +44,57 @@ provide a few options.
...
@@ -44,57 +44,57 @@ provide a few options.
Once your dataset is ready, you can begin training the model as follows:
Once your dataset is ready, you can begin training the model as follows:
```
bash
```
bash
python
keras
_imagenet_main.py
python
resnet
_imagenet_main.py
```
```
Again, if you did not download the data to the default directory, specify the
Again, if you did not download the data to the default directory, specify the
location with the
`--data_dir`
flag:
location with the
`--data_dir`
flag:
```
bash
```
bash
python
keras
_imagenet_main.py
--data_dir
=
/path/to/imagenet
python
resnet
_imagenet_main.py
--data_dir
=
/path/to/imagenet
```
```
There are more flag options you can specify. Here are some examples:
There are more flag options you can specify. Here are some examples:
-
`--use_synthetic_data`
: when set to true, synthetic data, rather than real
-
`--use_synthetic_data`
: when set to true, synthetic data, rather than real
data, are used;
data, are used;
-
`--batch_size`
: the batch size used for the model;
-
`--batch_size`
: the batch size used for the model;
-
`--model_dir`
: the directory to save the model checkpoint;
-
`--model_dir`
: the directory to save the model checkpoint;
-
`--train_epochs`
: number of epoches to run for training the model;
-
`--train_epochs`
: number of epoches to run for training the model;
-
`--train_steps`
: number of steps to run for training the model. We now only
-
`--train_steps`
: number of steps to run for training the model. We now only
support a number that is smaller than the number of batches in an epoch.
support a number that is smaller than the number of batches in an epoch.
-
`--skip_eval`
: when set to true, evaluation as well as validation during
-
`--skip_eval`
: when set to true, evaluation as well as validation during
training is skipped
training is skipped
For example, this is a typical command line to run with ImageNet data with
For example, this is a typical command line to run with ImageNet data with
batch size 128 per GPU:
batch size 128 per GPU:
```
bash
```
bash
python
-m
keras
_imagenet_main
\
python
-m
resnet
_imagenet_main
\
--model_dir
=
/tmp/model_dir/something
\
--model_dir
=
/tmp/model_dir/something
\
--num_gpus
=
2
\
--num_gpus
=
2
\
--batch_size
=
128
\
--batch_size
=
128
\
--train_epochs
=
90
\
--train_epochs
=
90
\
--train_steps
=
10
\
--train_steps
=
10
\
--use_synthetic_data
=
false
--use_synthetic_data
=
false
```
```
See
[
`
keras_
common.py`
](
keras_
common.py
)
for full list of options.
See
[
`common.py`
](
common.py
)
for full list of options.
## Using multiple GPUs
## Using multiple GPUs
You can train these models on multiple GPUs using
`tf.distribute.Strategy`
API.
You can train these models on multiple GPUs using
`tf.distribute.Strategy`
API.
You can read more about them in this
You can read more about them in this
[
guide
](
https://www.tensorflow.org/guide/distribute_strategy
)
.
[
guide
](
https://www.tensorflow.org/guide/distribute_strategy
)
.
In this example, we have made it easier to use is with just a command line flag
In this example, we have made it easier to use is with just a command line flag
`--num_gpus`
. By default this flag is 1 if TensorFlow is compiled with CUDA,
`--num_gpus`
. By default this flag is 1 if TensorFlow is compiled with CUDA,
and 0 otherwise.
and 0 otherwise.
-
--num_gpus=0: Uses tf.distribute.OneDeviceStrategy with CPU as the device.
-
--num_gpus=0: Uses tf.distribute.OneDeviceStrategy with CPU as the device.
-
--num_gpus=1: Uses tf.distribute.OneDeviceStrategy with GPU as the device.
-
--num_gpus=1: Uses tf.distribute.OneDeviceStrategy with GPU as the device.
-
--num_gpus=2+: Uses tf.distribute.MirroredStrategy to run synchronous
-
--num_gpus=2+: Uses tf.distribute.MirroredStrategy to run synchronous
distributed training across the GPUs.
distributed training across the GPUs.
If you wish to run without
`tf.distribute.Strategy`
, you can do so by setting
If you wish to run without
`tf.distribute.Strategy`
, you can do so by setting
`--distribution_strategy=off`
.
`--distribution_strategy=off`
.
official/
wide_deep
/__init__.py
→
official/
vision/image_classification
/__init__.py
View file @
bf748370
File moved
official/
resnet/keras
/cifar_preprocessing.py
→
official/
vision/image_classification
/cifar_preprocessing.py
View file @
bf748370
...
@@ -22,7 +22,7 @@ import os
...
@@ -22,7 +22,7 @@ import os
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.
resnet.keras
import
imagenet_preprocessing
from
official.
vision.image_classification
import
imagenet_preprocessing
HEIGHT
=
32
HEIGHT
=
32
WIDTH
=
32
WIDTH
=
32
...
...
official/
resnet/keras/keras_
common.py
→
official/
vision/image_classification/
common.py
View file @
bf748370
...
@@ -20,17 +20,13 @@ from __future__ import print_function
...
@@ -20,17 +20,13 @@ from __future__ import print_function
import
multiprocessing
import
multiprocessing
import
os
import
os
import
numpy
as
np
# pylint: disable=g-bad-import-order
from
absl
import
flags
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.keras.optimizer_v2
import
gradient_descent
as
gradient_descent_v2
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
# pylint: disable=ungrouped-imports
from
tensorflow.python.keras.optimizer_v2
import
(
gradient_descent
as
gradient_descent_v2
)
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
BASE_LEARNING_RATE
=
0.1
# This matches Jing's version.
BASE_LEARNING_RATE
=
0.1
# This matches Jing's version.
...
@@ -262,6 +258,7 @@ def define_keras_flags(dynamic_loss_scale=True):
...
@@ -262,6 +258,7 @@ def define_keras_flags(dynamic_loss_scale=True):
force_v2_in_keras_compile
=
True
)
force_v2_in_keras_compile
=
True
)
flags_core
.
define_image
()
flags_core
.
define_image
()
flags_core
.
define_benchmark
()
flags_core
.
define_benchmark
()
flags_core
.
define_distribution
()
flags
.
adopt_module_key_flags
(
flags_core
)
flags
.
adopt_module_key_flags
(
flags_core
)
flags
.
DEFINE_boolean
(
name
=
'enable_eager'
,
default
=
False
,
help
=
'Enable eager?'
)
flags
.
DEFINE_boolean
(
name
=
'enable_eager'
,
default
=
False
,
help
=
'Enable eager?'
)
...
...
official/
resnet/keras/keras_
common_test.py
→
official/
vision/image_classification/
common_test.py
View file @
bf748370
...
@@ -12,21 +12,21 @@
...
@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Tests for the
keras_
common module."""
"""Tests for the common module."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
print_function
from
mock
import
Mock
from
mock
import
Mock
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
from
tensorflow.python.platform
import
googletest
from
official.resnet.keras
import
keras_common
from
tensorflow.python.platform
import
googletest
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.vision.image_classification
import
common
class
KerasCommonTests
(
tf
.
test
.
TestCase
):
class
KerasCommonTests
(
tf
.
test
.
TestCase
):
"""Tests for
keras_
common."""
"""Tests for common."""
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
...
@@ -42,7 +42,7 @@ class KerasCommonTests(tf.test.TestCase):
...
@@ -42,7 +42,7 @@ class KerasCommonTests(tf.test.TestCase):
keras_utils
.
BatchTimestamp
(
1
,
2
),
keras_utils
.
BatchTimestamp
(
1
,
2
),
keras_utils
.
BatchTimestamp
(
2
,
3
)]
keras_utils
.
BatchTimestamp
(
2
,
3
)]
th
.
train_finish_time
=
12345
th
.
train_finish_time
=
12345
stats
=
keras_
common
.
build_stats
(
history
,
eval_output
,
[
th
])
stats
=
common
.
build_stats
(
history
,
eval_output
,
[
th
])
self
.
assertEqual
(
1.145
,
stats
[
'loss'
])
self
.
assertEqual
(
1.145
,
stats
[
'loss'
])
self
.
assertEqual
(.
99988
,
stats
[
'training_accuracy_top_1'
])
self
.
assertEqual
(.
99988
,
stats
[
'training_accuracy_top_1'
])
...
@@ -57,7 +57,7 @@ class KerasCommonTests(tf.test.TestCase):
...
@@ -57,7 +57,7 @@ class KerasCommonTests(tf.test.TestCase):
history
=
self
.
_build_history
(
1.145
,
cat_accuracy_sparse
=
.
99988
)
history
=
self
.
_build_history
(
1.145
,
cat_accuracy_sparse
=
.
99988
)
eval_output
=
self
.
_build_eval_output
(.
928
,
1.9844
)
eval_output
=
self
.
_build_eval_output
(.
928
,
1.9844
)
stats
=
keras_
common
.
build_stats
(
history
,
eval_output
,
None
)
stats
=
common
.
build_stats
(
history
,
eval_output
,
None
)
self
.
assertEqual
(
1.145
,
stats
[
'loss'
])
self
.
assertEqual
(
1.145
,
stats
[
'loss'
])
self
.
assertEqual
(.
99988
,
stats
[
'training_accuracy_top_1'
])
self
.
assertEqual
(.
99988
,
stats
[
'training_accuracy_top_1'
])
...
...
official/
resnet/keras
/imagenet_preprocessing.py
→
official/
vision/image_classification
/imagenet_preprocessing.py
View file @
bf748370
File moved
official/
resnet/keras/keras
_cifar_main.py
→
official/
vision/image_classification/resnet
_cifar_main.py
View file @
bf748370
...
@@ -22,13 +22,13 @@ from absl import app as absl_app
...
@@ -22,13 +22,13 @@ from absl import app as absl_app
from
absl
import
flags
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.resnet.keras
import
cifar_preprocessing
from
official.resnet.keras
import
keras_common
from
official.resnet.keras
import
resnet_cifar_model
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.logs
import
logger
from
official.utils.logs
import
logger
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.vision.image_classification
import
cifar_preprocessing
from
official.vision.image_classification
import
common
from
official.vision.image_classification
import
resnet_cifar_model
LR_SCHEDULE
=
[
# (multiplier, epoch to start) tuples
LR_SCHEDULE
=
[
# (multiplier, epoch to start) tuples
...
@@ -55,7 +55,7 @@ def learning_rate_schedule(current_epoch,
...
@@ -55,7 +55,7 @@ def learning_rate_schedule(current_epoch,
Adjusted learning rate.
Adjusted learning rate.
"""
"""
del
current_batch
,
batches_per_epoch
# not used
del
current_batch
,
batches_per_epoch
# not used
initial_learning_rate
=
keras_
common
.
BASE_LEARNING_RATE
*
batch_size
/
128
initial_learning_rate
=
common
.
BASE_LEARNING_RATE
*
batch_size
/
128
learning_rate
=
initial_learning_rate
learning_rate
=
initial_learning_rate
for
mult
,
start_epoch
in
LR_SCHEDULE
:
for
mult
,
start_epoch
in
LR_SCHEDULE
:
if
current_epoch
>=
start_epoch
:
if
current_epoch
>=
start_epoch
:
...
@@ -83,8 +83,8 @@ def run(flags_obj):
...
@@ -83,8 +83,8 @@ def run(flags_obj):
# Execute flag override logic for better model performance
# Execute flag override logic for better model performance
if
flags_obj
.
tf_gpu_thread_mode
:
if
flags_obj
.
tf_gpu_thread_mode
:
keras_
common
.
set_gpu_thread_mode_and_count
(
flags_obj
)
common
.
set_gpu_thread_mode_and_count
(
flags_obj
)
keras_
common
.
set_cudnn_batchnorm_mode
()
common
.
set_cudnn_batchnorm_mode
()
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
if
dtype
==
'fp16'
:
if
dtype
==
'fp16'
:
...
@@ -116,7 +116,7 @@ def run(flags_obj):
...
@@ -116,7 +116,7 @@ def run(flags_obj):
if
flags_obj
.
use_synthetic_data
:
if
flags_obj
.
use_synthetic_data
:
distribution_utils
.
set_up_synthetic_data
()
distribution_utils
.
set_up_synthetic_data
()
input_fn
=
keras_
common
.
get_synth_input_fn
(
input_fn
=
common
.
get_synth_input_fn
(
height
=
cifar_preprocessing
.
HEIGHT
,
height
=
cifar_preprocessing
.
HEIGHT
,
width
=
cifar_preprocessing
.
WIDTH
,
width
=
cifar_preprocessing
.
WIDTH
,
num_channels
=
cifar_preprocessing
.
NUM_CHANNELS
,
num_channels
=
cifar_preprocessing
.
NUM_CHANNELS
,
...
@@ -150,7 +150,7 @@ def run(flags_obj):
...
@@ -150,7 +150,7 @@ def run(flags_obj):
parse_record_fn
=
cifar_preprocessing
.
parse_record
)
parse_record_fn
=
cifar_preprocessing
.
parse_record
)
with
strategy_scope
:
with
strategy_scope
:
optimizer
=
keras_
common
.
get_optimizer
()
optimizer
=
common
.
get_optimizer
()
model
=
resnet_cifar_model
.
resnet56
(
classes
=
cifar_preprocessing
.
NUM_CLASSES
)
model
=
resnet_cifar_model
.
resnet56
(
classes
=
cifar_preprocessing
.
NUM_CLASSES
)
# TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
# TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
...
@@ -171,7 +171,7 @@ def run(flags_obj):
...
@@ -171,7 +171,7 @@ def run(flags_obj):
if
flags_obj
.
report_accuracy_metrics
else
None
),
if
flags_obj
.
report_accuracy_metrics
else
None
),
run_eagerly
=
flags_obj
.
run_eagerly
)
run_eagerly
=
flags_obj
.
run_eagerly
)
callbacks
=
keras_
common
.
get_callbacks
(
callbacks
=
common
.
get_callbacks
(
learning_rate_schedule
,
cifar_preprocessing
.
NUM_IMAGES
[
'train'
])
learning_rate_schedule
,
cifar_preprocessing
.
NUM_IMAGES
[
'train'
])
train_steps
=
cifar_preprocessing
.
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
train_steps
=
cifar_preprocessing
.
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
...
@@ -216,12 +216,12 @@ def run(flags_obj):
...
@@ -216,12 +216,12 @@ def run(flags_obj):
if
not
strategy
and
flags_obj
.
explicit_gpu_placement
:
if
not
strategy
and
flags_obj
.
explicit_gpu_placement
:
no_dist_strat_device
.
__exit__
()
no_dist_strat_device
.
__exit__
()
stats
=
keras_
common
.
build_stats
(
history
,
eval_output
,
callbacks
)
stats
=
common
.
build_stats
(
history
,
eval_output
,
callbacks
)
return
stats
return
stats
def
define_cifar_flags
():
def
define_cifar_flags
():
keras_
common
.
define_keras_flags
(
dynamic_loss_scale
=
False
)
common
.
define_keras_flags
(
dynamic_loss_scale
=
False
)
flags_core
.
set_defaults
(
data_dir
=
'/tmp/cifar10_data/cifar-10-batches-bin'
,
flags_core
.
set_defaults
(
data_dir
=
'/tmp/cifar10_data/cifar-10-batches-bin'
,
model_dir
=
'/tmp/cifar10_model'
,
model_dir
=
'/tmp/cifar10_model'
,
...
...
official/
resnet/keras
/resnet_cifar_model.py
→
official/
vision/image_classification
/resnet_cifar_model.py
View file @
bf748370
File moved
official/
resnet/keras/keras
_cifar_test.py
→
official/
vision/image_classification/resnet
_cifar_test.py
View file @
bf748370
...
@@ -18,17 +18,16 @@ from __future__ import absolute_import
...
@@ -18,17 +18,16 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
from
tempfile
import
mkdtemp
import
tempfile
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.resnet.keras
import
cifar_preprocessing
from
official.resnet.keras
import
keras_cifar_main
from
official.resnet.keras
import
keras_common
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
# pylint: disable=ungrouped-imports
from
tensorflow.python.eager
import
context
from
tensorflow.python.eager
import
context
from
tensorflow.python.platform
import
googletest
from
tensorflow.python.platform
import
googletest
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
from
official.vision.image_classification
import
cifar_preprocessing
from
official.vision.image_classification
import
resnet_cifar_main
class
KerasCifarTest
(
googletest
.
TestCase
):
class
KerasCifarTest
(
googletest
.
TestCase
):
...
@@ -43,13 +42,13 @@ class KerasCifarTest(googletest.TestCase):
...
@@ -43,13 +42,13 @@ class KerasCifarTest(googletest.TestCase):
def
get_temp_dir
(
self
):
def
get_temp_dir
(
self
):
if
not
self
.
_tempdir
:
if
not
self
.
_tempdir
:
self
.
_tempdir
=
mkdtemp
(
dir
=
googletest
.
GetTempDir
())
self
.
_tempdir
=
tempfile
.
mkdtemp
(
dir
=
googletest
.
GetTempDir
())
return
self
.
_tempdir
return
self
.
_tempdir
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
super
(
KerasCifarTest
,
cls
).
setUpClass
()
super
(
KerasCifarTest
,
cls
).
setUpClass
()
keras
_cifar_main
.
define_cifar_flags
()
resnet
_cifar_main
.
define_cifar_flags
()
def
setUp
(
self
):
def
setUp
(
self
):
super
(
KerasCifarTest
,
self
).
setUp
()
super
(
KerasCifarTest
,
self
).
setUp
()
...
@@ -72,7 +71,7 @@ class KerasCifarTest(googletest.TestCase):
...
@@ -72,7 +71,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
integration
.
run_synthetic
(
integration
.
run_synthetic
(
main
=
keras
_cifar_main
.
run
,
main
=
resnet
_cifar_main
.
run
,
tmp_root
=
self
.
get_temp_dir
(),
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
extra_flags
=
extra_flags
)
)
...
@@ -88,7 +87,7 @@ class KerasCifarTest(googletest.TestCase):
...
@@ -88,7 +87,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
integration
.
run_synthetic
(
integration
.
run_synthetic
(
main
=
keras
_cifar_main
.
run
,
main
=
resnet
_cifar_main
.
run
,
tmp_root
=
self
.
get_temp_dir
(),
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
extra_flags
=
extra_flags
)
)
...
@@ -112,7 +111,7 @@ class KerasCifarTest(googletest.TestCase):
...
@@ -112,7 +111,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
integration
.
run_synthetic
(
integration
.
run_synthetic
(
main
=
keras
_cifar_main
.
run
,
main
=
resnet
_cifar_main
.
run
,
tmp_root
=
self
.
get_temp_dir
(),
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
extra_flags
=
extra_flags
)
)
...
@@ -134,7 +133,7 @@ class KerasCifarTest(googletest.TestCase):
...
@@ -134,7 +133,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
integration
.
run_synthetic
(
integration
.
run_synthetic
(
main
=
keras
_cifar_main
.
run
,
main
=
resnet
_cifar_main
.
run
,
tmp_root
=
self
.
get_temp_dir
(),
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
extra_flags
=
extra_flags
)
)
...
@@ -157,7 +156,7 @@ class KerasCifarTest(googletest.TestCase):
...
@@ -157,7 +156,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
integration
.
run_synthetic
(
integration
.
run_synthetic
(
main
=
keras
_cifar_main
.
run
,
main
=
resnet
_cifar_main
.
run
,
tmp_root
=
self
.
get_temp_dir
(),
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
extra_flags
=
extra_flags
)
)
...
@@ -178,7 +177,7 @@ class KerasCifarTest(googletest.TestCase):
...
@@ -178,7 +177,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
integration
.
run_synthetic
(
integration
.
run_synthetic
(
main
=
keras
_cifar_main
.
run
,
main
=
resnet
_cifar_main
.
run
,
tmp_root
=
self
.
get_temp_dir
(),
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
extra_flags
=
extra_flags
)
)
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment