ModelZoo / ResNet50_tensorflow · Commit e170a8ba

Commit e170a8ba, authored Aug 30, 2019 by A. Unique TensorFlower
Parent: 765da424

Internal change

PiperOrigin-RevId: 266413847

Showing 5 changed files with 70 additions and 57 deletions (+70, −57):
- official/transformer/model/beam_search.py (+1, −1)
- official/transformer/v2/data_pipeline.py (+13, −6)
- official/transformer/v2/misc.py (+0, −4)
- official/transformer/v2/transformer_main.py (+38, −38)
- official/transformer/v2/translate.py (+18, −8)
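Despite the terse "Internal change" message, the diff appears to do three related things: switch beam search to the renamed `tf.tensor_scatter_nd_update` API, thread a `tf.distribute.InputContext` through the input pipelines so distributed workers read disjoint dataset shards, and rework the training loop around a checkpoint-derived `current_step` (dropping the `is_tpu_pod` flag and tagging per-replica translate inputs so outputs keep input order).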
official/transformer/model/beam_search.py

```diff
@@ -402,7 +402,7 @@ class SequenceBeamSearch(object):
     topk_ids = topk_indices % self.vocab_size

     if self.padded_decode:
       topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
-      topk_seq = tf.tensor_scatter_update(topk_seq, [i + 1], topk_ids)
+      topk_seq = tf.tensor_scatter_nd_update(topk_seq, [i + 1], topk_ids)
       topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
     else:
       topk_ids = tf.expand_dims(topk_ids, axis=2)
```
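The only change here tracks a TensorFlow API rename: `tf.tensor_scatter_update` became `tf.tensor_scatter_nd_update`. A minimal sketch of the renamed call, with toy shapes rather than the real `[length, batch, beam]` beam-search tensors: an index list of shape `[1]` selects one slice along axis 0, and `updates` replaces that whole slice.

```python
import tensorflow as tf

# Toy stand-ins for the transposed topk_seq / topk_ids above.
seq = tf.zeros([4, 2, 3], dtype=tf.int32)  # e.g. [length, batch, beam]
ids = tf.ones([2, 3], dtype=tf.int32)      # values for one decode step
i = 0

# Writes `ids` into slice i+1 of the first axis; all other slices are kept.
updated = tf.tensor_scatter_nd_update(seq, [i + 1], ids)
print(updated[1].numpy())  # all ones; rows 0, 2, 3 remain zero
```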
official/transformer/v2/data_pipeline.py

```diff
@@ -54,6 +54,7 @@ from __future__ import print_function
 import math
 import os

+from absl import logging
 import tensorflow as tf

 from official.transformer.v2 import misc
@@ -193,7 +194,7 @@ def _batch_examples(dataset, batch_size, max_length):
 def _read_and_batch_from_files(
     file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat,
-    static_batch=False, num_replicas=1):
+    static_batch=False, num_replicas=1, ctx=None):
   """Create dataset where each item is a dict of "inputs" and "targets".

   Args:
@@ -219,12 +220,17 @@ def _read_and_batch_from_files(
       batches, and each global batch is equally divisible by number of replicas.
       Currently it is only effective when static_batch==True. TODO: make it
       effective when static_batch=False.
+    ctx: Input context.

   Returns:
     tf.data.Dataset object containing examples loaded from the files.
   """
   dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle)
+  if ctx and ctx.num_input_pipelines > 1:
+    logging.info("Shard %d of the dataset.", ctx.input_pipeline_id)
+    dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)

   # Read files and interleave results. When training, the order of the examples
   # will be non-deterministic.
   options = tf.data.Options()
@@ -247,7 +253,7 @@ def _read_and_batch_from_files(
         # First calculate batch size (token number) per worker, then divide it
         # into sentences, and finally expand to a global batch. It could prove
         # the global batch divisible for distribution strategy.
-        ((batch_size // num_replicas) // max_length) * num_replicas,
+        int(batch_size // num_replicas // max_length * num_replicas),
         ([max_length], [max_length]),
         drop_remainder=True)
   else:
     # Group and batch such that each batch has examples of similar length.
@@ -276,7 +282,7 @@ def _generate_synthetic_data(params):
   return dataset.batch(batch, drop_remainder=True)


-def train_input_fn(params):
+def train_input_fn(params, ctx=None):
   """Load and return dataset of batched examples for use during training."""
   file_pattern = os.path.join(params["data_dir"] or "", "*train*")
   if params["use_synthetic_data"]:
@@ -285,10 +291,10 @@ def train_input_fn(params):
       file_pattern, params["batch_size"], params["max_length"],
       params["num_parallel_calls"], shuffle=True,
       repeat=params["repeat_dataset"], static_batch=params["static_batch"],
-      num_replicas=params["num_gpus"])
+      num_replicas=params["num_gpus"], ctx=ctx)


-def eval_input_fn(params):
+def eval_input_fn(params, ctx=None):
   """Load and return dataset of batched examples for use during evaluation."""
   file_pattern = os.path.join(params["data_dir"] or "", "*dev*")
   if params["use_synthetic_data"]:
@@ -296,7 +302,8 @@ def eval_input_fn(params):
   return _read_and_batch_from_files(
       file_pattern, params["batch_size"], params["max_length"],
       params["num_parallel_calls"], shuffle=False, repeat=1,
-      static_batch=params["static_batch"], num_replicas=params["num_gpus"])
+      static_batch=params["static_batch"], num_replicas=params["num_gpus"],
+      ctx=ctx)


 def map_data_for_transformer_fn(x, y):
```
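The thread running through this file is plumbing a `tf.distribute.InputContext` (`ctx`) into the input functions so that each input pipeline reads a disjoint shard of the files. A minimal, self-contained sketch of the same pattern, with a hypothetical file pattern; `num_input_pipelines` and `input_pipeline_id` are real attributes of `tf.distribute.InputContext`:

```python
import tensorflow as tf

def make_dataset(file_pattern, ctx=None):
  """Sketch of per-pipeline sharding as in _read_and_batch_from_files."""
  dataset = tf.data.Dataset.list_files(file_pattern, shuffle=True)
  # When a strategy builds one input pipeline per worker, each pipeline keeps
  # only every num_input_pipelines-th file, offset by its own id, so workers
  # never read the same records twice.
  if ctx and ctx.num_input_pipelines > 1:
    dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
  return dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
```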
official/transformer/v2/misc.py

```diff
@@ -182,10 +182,6 @@ def define_transformer_flags():
       default=False,
       help=flags_core.help_wrap(
           'Whether the model runs with custom training loop.'))
-  flags.DEFINE_bool(
-      name='is_tpu_pod',
-      default=False,
-      help=flags_core.help_wrap('Whether the model runs on a TPU pod.'))
   flags.DEFINE_bool(
       name='use_tpu_2vm_config',
       default=False,
```
official/transformer/v2/transformer_main.py

```diff
@@ -146,7 +146,6 @@ class TransformerTask(object):
     params["num_gpus"] = num_gpus
     params["use_ctl"] = flags_obj.use_ctl
-    params["is_tpu_pod"] = flags_obj.is_tpu_pod
     params["data_dir"] = flags_obj.data_dir
     params["model_dir"] = flags_obj.model_dir
     params["static_batch"] = flags_obj.static_batch
@@ -210,6 +209,15 @@ class TransformerTask(object):
     with distribution_utils.get_strategy_scope(self.distribution_strategy):
       model = transformer.create_model(params, is_train=True)
       opt = self._create_optimizer()

+      current_step = 0
+      checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
+      latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
+      if latest_checkpoint:
+        checkpoint.restore(latest_checkpoint)
+        logging.info("Loaded checkpoint %s", latest_checkpoint)
+        current_step = opt.iterations.numpy()
+
       if params["use_ctl"]:
         train_loss_metric = tf.keras.metrics.Mean(
             "training_loss", dtype=tf.float32)
```
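The training loop now restores the newest checkpoint at startup and derives the resume point from the optimizer's step counter, instead of only doing this on the TPU path further down. A hedged sketch of that resume idiom with a hypothetical model, optimizer, and directory; `optimizer.iterations` is the Keras optimizer's built-in step count:

```python
import tensorflow as tf

# Hypothetical model/optimizer standing in for the Transformer ones above.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt = tf.keras.optimizers.Adam()
model_dir = "/tmp/ckpts"  # hypothetical path

checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
latest = tf.train.latest_checkpoint(model_dir)  # None if no checkpoint yet
current_step = 0
if latest:
  checkpoint.restore(latest)
  # The optimizer's iteration counter is saved inside the checkpoint, so it
  # doubles as the global step to resume training from.
  current_step = opt.iterations.numpy()
```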
transformer_main.py, continued:

```diff
@@ -226,7 +234,7 @@ class TransformerTask(object):
       train_ds = (
           self.distribution_strategy
           .experimental_distribute_datasets_from_function(
-              lambda ctx: data_pipeline.train_input_fn(params)))
+              lambda ctx: data_pipeline.train_input_fn(params, ctx)))
     else:
       train_ds = data_pipeline.train_input_fn(params)
       map_data_fn = data_pipeline.map_data_for_transformer_fn
@@ -275,40 +283,33 @@ class TransformerTask(object):
         self.distribution_strategy.experimental_run_v2(
             _step_fn, args=(next(iterator),))

-    if self.use_tpu:
-      checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
-      latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
-      if latest_checkpoint:
-        checkpoint.restore(latest_checkpoint)
-        logging.info("Loaded checkpoint %s", latest_checkpoint)
-
     if flags_obj.train_steps < flags_obj.steps_between_evals:
       flags_obj.steps_between_evals = flags_obj.train_steps
-    iterations = flags_obj.train_steps // flags_obj.steps_between_evals

     cased_score, uncased_score = None, None
     cased_score_history, uncased_score_history = [], []
-    for i in range(1, iterations + 1):
-      print("Start train iteration:{}/{}".format(i, iterations))
+    while current_step < flags_obj.train_steps:
+      remaining_steps = flags_obj.train_steps - current_step
+      train_steps_per_eval = (
+          remaining_steps if remaining_steps < flags_obj.steps_between_evals
+          else flags_obj.steps_between_evals)
+      current_iteration = current_step // flags_obj.steps_between_evals
+
+      print("Start train iteration at global step:{}".format(current_step))
       history = None
       if params["use_ctl"]:
         if not self.use_tpu:
           raise NotImplementedError(
               "Custom training loop on GPUs is not implemented.")
-        train_steps_per_eval = tf.convert_to_tensor(
-            flags_obj.steps_between_evals, dtype=tf.int32)
-
         # Runs training steps.
-        train_steps(train_ds_iterator, train_steps_per_eval)
+        train_steps(train_ds_iterator,
+                    tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
+        current_step += train_steps_per_eval

         train_loss = train_loss_metric.result().numpy().astype(float)
         logging.info("Train Step: %d/%d / loss = %s",
-                     i * flags_obj.steps_between_evals,
-                     flags_obj.train_steps, train_loss)
+                     current_step, flags_obj.train_steps, train_loss)

         checkpoint_name = checkpoint.save(
             os.path.join(
                 flags_obj.model_dir,
-                "ctl_step_{}.ckpt".format(i * flags_obj.steps_between_evals)))
+                "ctl_step_{}.ckpt".format(current_step)))
         logging.info("Saved checkpoint to %s", checkpoint_name)
       else:
         if self.use_tpu:
@@ -316,24 +317,22 @@ class TransformerTask(object):
               "Keras model.fit on TPUs is not implemented.")
         history = model.fit(
             train_ds,
-            initial_epoch=i - 1,
-            epochs=i,
-            steps_per_epoch=flags_obj.steps_between_evals,
+            initial_epoch=current_iteration,
+            epochs=current_iteration + 1,
+            steps_per_epoch=train_steps_per_eval,
             callbacks=callbacks,
             # If TimeHistory is enabled, progress bar would be messy. Increase
             # the verbose level to get rid of it.
             verbose=(2 if flags_obj.enable_time_history else 1))
+        current_step += train_steps_per_eval
         logging.info("Train history: {}".format(history.history))

-      print("End train iteration:{}/{} global step:{}".format(
-          i, iterations, i * flags_obj.steps_between_evals))
+      print("End train iteration at global step:{}".format(current_step))

       if (flags_obj.bleu_source and flags_obj.bleu_ref):
         uncased_score, cased_score = self.eval()
-        cased_score_history.append([i, cased_score])
-        uncased_score_history.append([i, uncased_score])
+        cased_score_history.append([current_iteration + 1, cased_score])
+        uncased_score_history.append([current_iteration + 1, uncased_score])

     stats = ({
         "loss": train_loss
@@ -347,12 +346,13 @@ class TransformerTask(object):
   def eval(self):
     """Evaluates the model."""
-    if not self.predict_model:
-      self.predict_model = transformer.create_model(self.params, False)
-    self._load_weights_if_possible(
-        self.predict_model,
-        tf.train.latest_checkpoint(self.flags_obj.model_dir))
-    self.predict_model.summary()
+    with distribution_utils.get_strategy_scope(self.distribution_strategy):
+      if not self.predict_model:
+        self.predict_model = transformer.create_model(self.params, False)
+      self._load_weights_if_possible(
+          self.predict_model,
+          tf.train.latest_checkpoint(self.flags_obj.model_dir))
+      self.predict_model.summary()
     return evaluate_and_log_bleu(
         self.predict_model, self.params, self.flags_obj.bleu_source,
         self.flags_obj.bleu_ref, self.flags_obj.vocab_file,
@@ -430,7 +430,7 @@ class TransformerTask(object):
       # which will ensure tf.keras.mixed_precision and
       # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
      # double up.
       opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
     return opt
```
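The dataset is now built through `experimental_distribute_datasets_from_function`, whose callback receives the `tf.distribute.InputContext` that `train_input_fn` forwards to the sharding logic in data_pipeline.py. A minimal sketch of that wiring, with a hypothetical strategy and params dict; the method name is the one used at the time of this commit (it was later renamed `distribute_datasets_from_function`):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # hypothetical strategy choice

def train_input_fn(params, ctx=None):
  # Each input pipeline builds (and, via ctx, shards) its own dataset.
  dataset = tf.data.Dataset.range(1000).batch(params["batch_size"])
  if ctx and ctx.num_input_pipelines > 1:
    dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
  return dataset

params = {"batch_size": 32}
# The strategy invokes the lambda once per input pipeline, passing the
# InputContext -- this is why the lambda's ctx must be forwarded, not dropped.
train_ds = strategy.experimental_distribute_datasets_from_function(
    lambda ctx: train_input_fn(params, ctx))
```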
official/transformer/v2/translate.py

```diff
@@ -128,8 +128,10 @@ def translate_file(model,
     def _step_fn(inputs):
       """Per replica step function."""
-      val_outputs, _ = model([inputs], training=False)
-      return val_outputs
+      tag = inputs[0]
+      val_inputs = inputs[1]
+      val_outputs, _ = model([val_inputs], training=False)
+      return tag, val_outputs

     return distribution_strategy.experimental_run_v2(_step_fn, args=(inputs,))
@@ -140,17 +142,25 @@ def translate_file(model,
   for i, text in enumerate(input_generator()):
     if distribution_strategy:
       text = np.reshape(text, [num_replicas, local_batch_size, -1])
-      text = [tf.convert_to_tensor(per_replica_text)
-              for per_replica_text in text]
+      # Add tag to the input of each replica with the reordering logic after
+      # outputs, to ensure the output order matches the input order.
+      text = [
+          [tf.convert_to_tensor(tag), tf.convert_to_tensor(per_replica_text)]
+          for tag, per_replica_text in enumerate(text)
+      ]
       # pylint: disable=protected-access
       text = values.PerReplica(distribution_strategy.extended._device_map,
                                text)
-      # pylint: enable=protected-access
-      val_outputs = distribution_strategy.experimental_local_results(
-          predict_step(text))
-      val_outputs = np.reshape([val_output.numpy()
-                                for val_output in val_outputs],
-                               [params["decode_batch_size"], -1])
+      outputs = distribution_strategy.experimental_local_results(
+          predict_step(text))
+      tags, unordered_val_outputs = outputs[0]
+      tags = [tag.numpy() for tag in tags._values]
+      unordered_val_outputs = [
+          val_output.numpy() for val_output in unordered_val_outputs._values]
+      # pylint: enable=protected-access
+      val_outputs = [None] * len(tags)
+      for k in range(len(tags)):
+        val_outputs[tags[k]] = unordered_val_outputs[k]
+      val_outputs = np.reshape(val_outputs,
+                               [params["decode_batch_size"], -1])
     else:
       val_outputs, _ = model.predict(text)
```
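The translate change tags each replica's input with its index so that per-replica outputs, which `experimental_local_results` may return in device rather than input order, can be restored to input order afterwards. A self-contained sketch of just that reordering logic, using plain Python lists as stand-ins for the PerReplica values:

```python
# Sketch: replicas return (tag, output) pairs; tags restore input order.
tags = [2, 0, 1]                       # order the strategy happened to return
unordered_outputs = ["out2", "out0", "out1"]

val_outputs = [None] * len(tags)
for k in range(len(tags)):
  # Place each output at the position its tag says the input came from.
  val_outputs[tags[k]] = unordered_outputs[k]

assert val_outputs == ["out0", "out1", "out2"]
```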